mirror of
https://github.com/openai/codex.git
synced 2026-02-03 07:23:39 +00:00
Compare commits
28 Commits
custom-ins
...
interrupt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9780b6d550 | ||
|
|
db54493c92 | ||
|
|
aabce31e84 | ||
|
|
78c6f0eb70 | ||
|
|
098462494e | ||
|
|
761ea58759 | ||
|
|
4a7b8aaace | ||
|
|
2e30a84c68 | ||
|
|
52d6655de9 | ||
|
|
3ef1f26ecc | ||
|
|
aad6dc1e4c | ||
|
|
aa4f9dff7a | ||
|
|
3baa5a73ae | ||
|
|
fb8622ac6a | ||
|
|
0b30945eef | ||
|
|
790c5ace10 | ||
|
|
7bcc77bb3c | ||
|
|
80bc428b37 | ||
|
|
9b3e1a8b56 | ||
|
|
666a546adc | ||
|
|
f90d91b1c3 | ||
|
|
b73b211ee5 | ||
|
|
2bb8d37b12 | ||
|
|
79825c08f1 | ||
|
|
4758897e6f | ||
|
|
6655653d77 | ||
|
|
df04fddbc4 | ||
|
|
47725f9fa8 |
@@ -21,7 +21,7 @@
|
||||
"settings": {
|
||||
"terminal.integrated.defaultProfile.linux": "bash"
|
||||
},
|
||||
"extensions": ["rust-lang.rust-analyzer", "tamasfe.even-better-toml"]
|
||||
"extensions": ["rust-lang.rust-analyzer"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
12
.github/actions/codex/bun.lock
vendored
12
.github/actions/codex/bun.lock
vendored
@@ -8,8 +8,8 @@
|
||||
"@actions/github": "^6.0.1",
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/bun": "^1.2.19",
|
||||
"@types/node": "^24.1.0",
|
||||
"@types/bun": "^1.2.18",
|
||||
"@types/node": "^24.0.13",
|
||||
"prettier": "^3.6.2",
|
||||
"typescript": "^5.8.3",
|
||||
},
|
||||
@@ -48,15 +48,15 @@
|
||||
|
||||
"@octokit/types": ["@octokit/types@13.10.0", "", { "dependencies": { "@octokit/openapi-types": "^24.2.0" } }, "sha512-ifLaO34EbbPj0Xgro4G5lP5asESjwHracYJvVaPIyXMuiuXLlhic3S47cBdTb+jfODkTE5YtGCLt3Ay3+J97sA=="],
|
||||
|
||||
"@types/bun": ["@types/bun@1.2.19", "", { "dependencies": { "bun-types": "1.2.19" } }, "sha512-d9ZCmrH3CJ2uYKXQIUuZ/pUnTqIvLDS0SK7pFmbx8ma+ziH/FRMoAq5bYpRG7y+w1gl+HgyNZbtqgMq4W4e2Lg=="],
|
||||
"@types/bun": ["@types/bun@1.2.18", "", { "dependencies": { "bun-types": "1.2.18" } }, "sha512-Xf6RaWVheyemaThV0kUfaAUvCNokFr+bH8Jxp+tTZfx7dAPA8z9ePnP9S9+Vspzuxxx9JRAXhnyccRj3GyCMdQ=="],
|
||||
|
||||
"@types/node": ["@types/node@24.1.0", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-ut5FthK5moxFKH2T1CUOC6ctR67rQRvvHdFLCD2Ql6KXmMuCrjsSsRI9UsLCm9M18BMwClv4pn327UvB7eeO1w=="],
|
||||
"@types/node": ["@types/node@24.0.13", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-Qm9OYVOFHFYg3wJoTSrz80hoec5Lia/dPp84do3X7dZvLikQvM1YpmvTBEdIr/e+U8HTkFjLHLnl78K/qjf+jQ=="],
|
||||
|
||||
"@types/react": ["@types/react@19.1.8", "", { "dependencies": { "csstype": "^3.0.2" } }, "sha512-AwAfQ2Wa5bCx9WP8nZL2uMZWod7J7/JSplxbTmBQ5ms6QpqNYm672H0Vu9ZVKVngQ+ii4R/byguVEUZQyeg44g=="],
|
||||
|
||||
"before-after-hook": ["before-after-hook@2.2.3", "", {}, "sha512-NzUnlZexiaH/46WDhANlyR2bXRopNg4F/zuSA3OpZnllCUgRaOF2znDioDWrmbNVsuZk6l9pMquQB38cfBZwkQ=="],
|
||||
|
||||
"bun-types": ["bun-types@1.2.19", "", { "dependencies": { "@types/node": "*" }, "peerDependencies": { "@types/react": "^19" } }, "sha512-uAOTaZSPuYsWIXRpj7o56Let0g/wjihKCkeRqUBhlLVM/Bt+Fj9xTo+LhC1OV1XDaGkz4hNC80et5xgy+9KTHQ=="],
|
||||
"bun-types": ["bun-types@1.2.18", "", { "dependencies": { "@types/node": "*" }, "peerDependencies": { "@types/react": "^19" } }, "sha512-04+Eha5NP7Z0A9YgDAzMk5PHR16ZuLVa83b26kH5+cp1qZW4F6FmAURngE7INf4tKOvCE69vYvDEwoNl1tGiWw=="],
|
||||
|
||||
"csstype": ["csstype@3.1.3", "", {}, "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw=="],
|
||||
|
||||
@@ -82,8 +82,6 @@
|
||||
|
||||
"@octokit/plugin-rest-endpoint-methods/@octokit/types": ["@octokit/types@12.6.0", "", { "dependencies": { "@octokit/openapi-types": "^20.0.0" } }, "sha512-1rhSOfRa6H9w4YwK0yrf5faDaDTb+yLyBUKOCV4xtCDB5VmIPqd/v9yr9o6SAzOAlRxMiRiCic6JVM1/kunVkw=="],
|
||||
|
||||
"bun-types/@types/node": ["@types/node@24.0.13", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-Qm9OYVOFHFYg3wJoTSrz80hoec5Lia/dPp84do3X7dZvLikQvM1YpmvTBEdIr/e+U8HTkFjLHLnl78K/qjf+jQ=="],
|
||||
|
||||
"@octokit/plugin-paginate-rest/@octokit/types/@octokit/openapi-types": ["@octokit/openapi-types@20.0.0", "", {}, "sha512-EtqRBEjp1dL/15V7WiX5LJMIxxkdiGJnabzYx5Apx4FkQIFgAfKumXeYAqqJCj1s+BMX4cPFIFC4OLCR6stlnA=="],
|
||||
|
||||
"@octokit/plugin-rest-endpoint-methods/@octokit/types/@octokit/openapi-types": ["@octokit/openapi-types@20.0.0", "", {}, "sha512-EtqRBEjp1dL/15V7WiX5LJMIxxkdiGJnabzYx5Apx4FkQIFgAfKumXeYAqqJCj1s+BMX4cPFIFC4OLCR6stlnA=="],
|
||||
|
||||
4
.github/actions/codex/package.json
vendored
4
.github/actions/codex/package.json
vendored
@@ -13,8 +13,8 @@
|
||||
"@actions/github": "^6.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/bun": "^1.2.19",
|
||||
"@types/node": "^24.1.0",
|
||||
"@types/bun": "^1.2.18",
|
||||
"@types/node": "^24.0.13",
|
||||
"prettier": "^3.6.2",
|
||||
"typescript": "^5.8.3"
|
||||
}
|
||||
|
||||
2
.github/workflows/rust-release.yml
vendored
2
.github/workflows/rust-release.yml
vendored
@@ -93,7 +93,7 @@ jobs:
|
||||
sudo apt install -y musl-tools pkg-config
|
||||
|
||||
- name: Cargo build
|
||||
run: cargo build --target ${{ matrix.target }} --release --bin codex --bin codex-exec --bin codex-linux-sandbox
|
||||
run: cargo build --target ${{ matrix.target }} --release --all-targets --all-features
|
||||
|
||||
- name: Stage artifacts
|
||||
shell: bash
|
||||
|
||||
@@ -370,26 +370,11 @@ export function isSafeCommand(
|
||||
reason: "View file with line numbers",
|
||||
group: "Reading files",
|
||||
};
|
||||
case "rg": {
|
||||
// Certain ripgrep options execute external commands or invoke other
|
||||
// processes, so we must reject them.
|
||||
const isUnsafe = command.some(
|
||||
(arg: string) =>
|
||||
UNSAFE_OPTIONS_FOR_RIPGREP_WITHOUT_ARGS.has(arg) ||
|
||||
[...UNSAFE_OPTIONS_FOR_RIPGREP_WITH_ARGS].some(
|
||||
(opt) => arg === opt || arg.startsWith(`${opt}=`),
|
||||
),
|
||||
);
|
||||
|
||||
if (isUnsafe) {
|
||||
break;
|
||||
}
|
||||
|
||||
case "rg":
|
||||
return {
|
||||
reason: "Ripgrep search",
|
||||
group: "Searching",
|
||||
};
|
||||
}
|
||||
case "find": {
|
||||
// Certain options to `find` allow executing arbitrary processes, so we
|
||||
// cannot auto-approve them.
|
||||
@@ -510,22 +495,6 @@ const UNSAFE_OPTIONS_FOR_FIND_COMMAND: ReadonlySet<string> = new Set([
|
||||
"-fprintf",
|
||||
]);
|
||||
|
||||
// Ripgrep options that are considered unsafe because they may execute
|
||||
// arbitrary commands or spawn auxiliary processes.
|
||||
const UNSAFE_OPTIONS_FOR_RIPGREP_WITH_ARGS: ReadonlySet<string> = new Set([
|
||||
// Executes an arbitrary command for each matching file.
|
||||
"--pre",
|
||||
// Allows custom hostname command which could leak environment details.
|
||||
"--hostname-bin",
|
||||
]);
|
||||
|
||||
const UNSAFE_OPTIONS_FOR_RIPGREP_WITHOUT_ARGS: ReadonlySet<string> = new Set([
|
||||
// Enables searching inside archives which triggers external decompression
|
||||
// utilities – reject out of an abundance of caution.
|
||||
"--search-zip",
|
||||
"-z",
|
||||
]);
|
||||
|
||||
// ---------------- Helper utilities for complex shell expressions -----------------
|
||||
|
||||
// A conservative allow-list of bash operators that do not, on their own, cause
|
||||
|
||||
@@ -44,14 +44,6 @@ describe("canAutoApprove()", () => {
|
||||
group: "Navigating",
|
||||
runInSandbox: false,
|
||||
});
|
||||
|
||||
// Ripgrep safe invocation.
|
||||
expect(check(["rg", "TODO"])).toEqual({
|
||||
type: "auto-approve",
|
||||
reason: "Ripgrep search",
|
||||
group: "Searching",
|
||||
runInSandbox: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("simple safe commands within a `bash -lc` call", () => {
|
||||
@@ -75,24 +67,6 @@ describe("canAutoApprove()", () => {
|
||||
});
|
||||
});
|
||||
|
||||
test("ripgrep unsafe flags", () => {
|
||||
// Flags that do not take arguments
|
||||
expect(check(["rg", "--search-zip", "TODO"])).toEqual({ type: "ask-user" });
|
||||
expect(check(["rg", "-z", "TODO"])).toEqual({ type: "ask-user" });
|
||||
|
||||
// Flags that take arguments (provided separately)
|
||||
expect(check(["rg", "--pre", "cat", "TODO"])).toEqual({ type: "ask-user" });
|
||||
expect(check(["rg", "--hostname-bin", "hostname", "TODO"])).toEqual({
|
||||
type: "ask-user",
|
||||
});
|
||||
|
||||
// Flags that take arguments in = form
|
||||
expect(check(["rg", "--pre=cat", "TODO"])).toEqual({ type: "ask-user" });
|
||||
expect(check(["rg", "--hostname-bin=hostname", "TODO"])).toEqual({
|
||||
type: "ask-user",
|
||||
});
|
||||
});
|
||||
|
||||
test("bash -lc commands with unsafe redirects", () => {
|
||||
expect(check(["bash", "-lc", "echo hello > file.txt"])).toEqual({
|
||||
type: "ask-user",
|
||||
|
||||
214
codex-rs/Cargo.lock
generated
214
codex-rs/Cargo.lock
generated
@@ -463,18 +463,18 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||
|
||||
[[package]]
|
||||
name = "castaway"
|
||||
version = "0.2.4"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
|
||||
checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5"
|
||||
dependencies = [
|
||||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.30"
|
||||
version = "1.2.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7"
|
||||
checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
@@ -570,9 +570,9 @@ checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
|
||||
|
||||
[[package]]
|
||||
name = "clipboard-win"
|
||||
version = "5.4.1"
|
||||
version = "5.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4"
|
||||
checksum = "15efe7a882b08f34e38556b14f2fb3daa98769d06c7f0c1b076dfd0d983bc892"
|
||||
dependencies = [
|
||||
"error-code",
|
||||
]
|
||||
@@ -605,18 +605,6 @@ dependencies = [
|
||||
"tree-sitter-bash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-arg0"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"codex-apply-patch",
|
||||
"codex-core",
|
||||
"codex-linux-sandbox",
|
||||
"dotenvy",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-chatgpt"
|
||||
version = "0.0.0"
|
||||
@@ -640,11 +628,11 @@ dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"clap_complete",
|
||||
"codex-arg0",
|
||||
"codex-chatgpt",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"codex-exec",
|
||||
"codex-linux-sandbox",
|
||||
"codex-login",
|
||||
"codex-mcp-server",
|
||||
"codex-tui",
|
||||
@@ -661,7 +649,7 @@ dependencies = [
|
||||
"clap",
|
||||
"codex-core",
|
||||
"serde",
|
||||
"toml 0.9.2",
|
||||
"toml 0.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -675,41 +663,37 @@ dependencies = [
|
||||
"bytes",
|
||||
"codex-apply-patch",
|
||||
"codex-mcp-client",
|
||||
"core_test_support",
|
||||
"dirs",
|
||||
"env-flags",
|
||||
"eventsource-stream",
|
||||
"fs2",
|
||||
"futures",
|
||||
"landlock",
|
||||
"libc",
|
||||
"maplit",
|
||||
"mcp-types",
|
||||
"mime_guess",
|
||||
"openssl-sys",
|
||||
"predicates",
|
||||
"pretty_assertions",
|
||||
"rand 0.9.2",
|
||||
"rand 0.9.1",
|
||||
"reqwest",
|
||||
"seccompiler",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"shlex",
|
||||
"strum_macros 0.27.2",
|
||||
"strum_macros 0.27.1",
|
||||
"tempfile",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tokio-util",
|
||||
"toml 0.9.2",
|
||||
"toml 0.9.1",
|
||||
"tracing",
|
||||
"tree-sitter",
|
||||
"tree-sitter-bash",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
"whoami",
|
||||
"wildmatch",
|
||||
"wiremock",
|
||||
]
|
||||
@@ -719,17 +703,14 @@ name = "codex-exec"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-arg0",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"codex-linux-sandbox",
|
||||
"owo-colors",
|
||||
"predicates",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
@@ -774,7 +755,6 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"landlock",
|
||||
"libc",
|
||||
@@ -812,24 +792,17 @@ name = "codex-mcp-server"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"codex-arg0",
|
||||
"codex-core",
|
||||
"codex-linux-sandbox",
|
||||
"mcp-types",
|
||||
"mcp_test_support",
|
||||
"pretty_assertions",
|
||||
"schemars 0.8.22",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"toml 0.9.2",
|
||||
"toml 0.9.1",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -840,10 +813,10 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"clap",
|
||||
"codex-ansi-escape",
|
||||
"codex-arg0",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"codex-file-search",
|
||||
"codex-linux-sandbox",
|
||||
"codex-login",
|
||||
"color-eyre",
|
||||
"crossterm",
|
||||
@@ -858,9 +831,8 @@ dependencies = [
|
||||
"regex-lite",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"strum 0.27.2",
|
||||
"strum_macros 0.27.2",
|
||||
"tempfile",
|
||||
"strum 0.27.1",
|
||||
"strum_macros 0.27.1",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
@@ -869,7 +841,6 @@ dependencies = [
|
||||
"tui-markdown",
|
||||
"tui-textarea",
|
||||
"unicode-segmentation",
|
||||
"unicode-width 0.1.14",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
@@ -972,16 +943,6 @@ version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "core_test_support"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"codex-core",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.17"
|
||||
@@ -993,9 +954,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.5.0"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
|
||||
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
@@ -1304,12 +1265,6 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
|
||||
|
||||
[[package]]
|
||||
name = "dotenvy"
|
||||
version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||
|
||||
[[package]]
|
||||
name = "dupe"
|
||||
version = "0.9.1"
|
||||
@@ -1542,7 +1497,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"rustix 1.0.8",
|
||||
"rustix 1.0.7",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -1991,9 +1946,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.16"
|
||||
version = "0.1.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e"
|
||||
checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
@@ -2007,7 +1962,7 @@ dependencies = [
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.6.0",
|
||||
"socket2",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
@@ -2260,9 +2215,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "instability"
|
||||
version = "0.3.9"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "435d80800b936787d62688c927b6490e887c7ef5ff9ce922c6c6050fca75eb9a"
|
||||
checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"indoc",
|
||||
@@ -2293,9 +2248,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "io-uring"
|
||||
version = "0.7.9"
|
||||
version = "0.7.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
|
||||
checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"cfg-if",
|
||||
@@ -2499,9 +2454,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.6"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4488594b9328dee448adb906d8b126d9b7deb7cf5c22161ee591610bb1be83c0"
|
||||
checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"libc",
|
||||
@@ -2634,22 +2589,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp_test_support"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"codex-mcp-server",
|
||||
"mcp-types",
|
||||
"pretty_assertions",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.5"
|
||||
@@ -3325,9 +3264,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.9.2"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
||||
dependencies = [
|
||||
"rand_chacha 0.9.0",
|
||||
"rand_core 0.9.3",
|
||||
@@ -3374,7 +3313,8 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "ratatui"
|
||||
version = "0.29.0"
|
||||
source = "git+https://github.com/nornagon/ratatui?branch=nornagon-v0.29.0-patch#bca287ddc5d38fe088c79e2eda22422b96226f2e"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"cassowary",
|
||||
@@ -3479,9 +3419,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.15"
|
||||
version = "0.5.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec"
|
||||
checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
]
|
||||
@@ -3629,9 +3569,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rgb"
|
||||
version = "0.8.52"
|
||||
version = "0.8.51"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c6a884d2998352bb4daf0183589aec883f16a6da1f4dde84d8e2e9a5409a1ce"
|
||||
checksum = "a457e416a0f90d246a4c3288bd7a25b2304ca727f253f95be383dd17af56be8f"
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
@@ -3707,22 +3647,22 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.0.8"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8"
|
||||
checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.9.4",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.29"
|
||||
version = "0.23.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1"
|
||||
checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"rustls-pki-types",
|
||||
@@ -3742,9 +3682,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.103.4"
|
||||
version = "0.103.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc"
|
||||
checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
@@ -3970,9 +3910,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.141"
|
||||
version = "1.0.140"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3"
|
||||
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
|
||||
dependencies = [
|
||||
"indexmap 2.10.0",
|
||||
"itoa",
|
||||
@@ -4165,16 +4105,6 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.0"
|
||||
@@ -4318,9 +4248,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.27.2"
|
||||
version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf"
|
||||
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
@@ -4337,13 +4267,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
version = "0.27.2"
|
||||
version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7"
|
||||
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
@@ -4466,7 +4397,7 @@ dependencies = [
|
||||
"fastrand",
|
||||
"getrandom 0.3.3",
|
||||
"once_cell",
|
||||
"rustix 1.0.8",
|
||||
"rustix 1.0.7",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -4487,7 +4418,7 @@ version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed"
|
||||
dependencies = [
|
||||
"rustix 1.0.8",
|
||||
"rustix 1.0.7",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -4633,7 +4564,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"slab",
|
||||
"socket2 0.5.10",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
@@ -4720,9 +4651,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.9.2"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed0aee96c12fa71097902e0bb061a5e1ebd766a6636bb605ba401c45c1650eac"
|
||||
checksum = "0207d6ed1852c2a124c1fbec61621acb8330d2bf969a5d0643131e9affd985a5"
|
||||
dependencies = [
|
||||
"indexmap 2.10.0",
|
||||
"serde",
|
||||
@@ -4766,18 +4697,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_parser"
|
||||
version = "1.0.1"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97200572db069e74c512a14117b296ba0a80a30123fbbb5aa1f4a348f639ca30"
|
||||
checksum = "b5c1c469eda89749d2230d8156a5969a69ffe0d6d01200581cdc6110674d293e"
|
||||
dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_writer"
|
||||
version = "1.0.2"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64"
|
||||
checksum = "b679217f2848de74cabd3e8fc5e6d66f40b7da40f8e1954d92054d9010690fd5"
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
@@ -4910,9 +4841,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.25.8"
|
||||
version = "0.25.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d7b8994f367f16e6fa14b5aebbcb350de5d7cbea82dc5b00ae997dd71680dd2"
|
||||
checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"regex",
|
||||
@@ -5154,12 +5085,6 @@ dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.100"
|
||||
@@ -5260,17 +5185,6 @@ version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7"
|
||||
dependencies = [
|
||||
"redox_syscall",
|
||||
"wasite",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wildmatch"
|
||||
version = "2.4.0"
|
||||
@@ -5599,9 +5513,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.12"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95"
|
||||
checksum = "74c7b26e3480b707944fc872477815d29a8e429d2f93a1ce000f5fa84a15cbcd"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
@@ -3,7 +3,6 @@ resolver = "2"
|
||||
members = [
|
||||
"ansi-escape",
|
||||
"apply-patch",
|
||||
"arg0",
|
||||
"cli",
|
||||
"common",
|
||||
"core",
|
||||
@@ -41,8 +40,3 @@ strip = "symbols"
|
||||
|
||||
# See https://github.com/openai/codex/issues/1411 for details.
|
||||
codegen-units = 1
|
||||
|
||||
[patch.crates-io]
|
||||
# ratatui = { path = "../../ratatui" }
|
||||
ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" }
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ workspace = true
|
||||
anyhow = "1"
|
||||
similar = "2.7.0"
|
||||
thiserror = "2.0.12"
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter = "0.25.3"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
[package]
|
||||
name = "codex-arg0"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
name = "codex_arg0"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
codex-apply-patch = { path = "../apply-patch" }
|
||||
codex-core = { path = "../core" }
|
||||
codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
dotenvy = "0.15.7"
|
||||
tokio = { version = "1", features = ["rt-multi-thread"] }
|
||||
@@ -1,89 +0,0 @@
|
||||
use std::future::Future;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// While we want to deploy the Codex CLI as a single executable for simplicity,
|
||||
/// we also want to expose some of its functionality as distinct CLIs, so we use
|
||||
/// the "arg0 trick" to determine which CLI to dispatch. This effectively allows
|
||||
/// us to simulate deploying multiple executables as a single binary on Mac and
|
||||
/// Linux (but not Windows).
|
||||
///
|
||||
/// When the current executable is invoked through the hard-link or alias named
|
||||
/// `codex-linux-sandbox` we *directly* execute
|
||||
/// [`codex_linux_sandbox::run_main`] (which never returns). Otherwise we:
|
||||
///
|
||||
/// 1. Use [`dotenvy::from_path`] and [`dotenvy::dotenv`] to modify the
|
||||
/// environment before creating any threads.
|
||||
/// 2. Construct a Tokio multi-thread runtime.
|
||||
/// 3. Derive the path to the current executable (so children can re-invoke the
|
||||
/// sandbox) when running on Linux.
|
||||
/// 4. Execute the provided async `main_fn` inside that runtime, forwarding any
|
||||
/// error. Note that `main_fn` receives `codex_linux_sandbox_exe:
|
||||
/// Option<PathBuf>`, as an argument, which is generally needed as part of
|
||||
/// constructing [`codex_core::config::Config`].
|
||||
///
|
||||
/// This function should be used to wrap any `main()` function in binary crates
|
||||
/// in this workspace that depends on these helper CLIs.
|
||||
pub fn arg0_dispatch_or_else<F, Fut>(main_fn: F) -> anyhow::Result<()>
|
||||
where
|
||||
F: FnOnce(Option<PathBuf>) -> Fut,
|
||||
Fut: Future<Output = anyhow::Result<()>>,
|
||||
{
|
||||
// Determine if we were invoked via the special alias.
|
||||
let mut args = std::env::args_os();
|
||||
let argv0 = args.next().unwrap_or_default();
|
||||
let exe_name = Path::new(&argv0)
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if exe_name == "codex-linux-sandbox" {
|
||||
// Safety: [`run_main`] never returns.
|
||||
codex_linux_sandbox::run_main();
|
||||
}
|
||||
|
||||
let argv1 = args.next().unwrap_or_default();
|
||||
if argv1 == "--codex-run-as-apply-patch" {
|
||||
let patch_arg = args.next().and_then(|s| s.to_str().map(|s| s.to_owned()));
|
||||
let exit_code = match patch_arg {
|
||||
Some(patch_arg) => {
|
||||
let mut stdout = std::io::stdout();
|
||||
let mut stderr = std::io::stderr();
|
||||
match codex_apply_patch::apply_patch(&patch_arg, &mut stdout, &mut stderr) {
|
||||
Ok(()) => 0,
|
||||
Err(_) => 1,
|
||||
}
|
||||
}
|
||||
None => {
|
||||
eprintln!("Error: --codex-run-as-apply-patch requires a UTF-8 PATCH argument.");
|
||||
1
|
||||
}
|
||||
};
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
|
||||
// This modifies the environment, which is not thread-safe, so do this
|
||||
// before creating any threads/the Tokio runtime.
|
||||
load_dotenv();
|
||||
|
||||
// Regular invocation – create a Tokio runtime and execute the provided
|
||||
// async entry-point.
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
runtime.block_on(async move {
|
||||
let codex_linux_sandbox_exe: Option<PathBuf> = if cfg!(target_os = "linux") {
|
||||
std::env::current_exe().ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
main_fn(codex_linux_sandbox_exe).await
|
||||
})
|
||||
}
|
||||
|
||||
/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
|
||||
fn load_dotenv() {
|
||||
if let Ok(codex_home) = codex_core::config::find_codex_home() {
|
||||
dotenvy::from_path(codex_home.join(".env")).ok();
|
||||
}
|
||||
dotenvy::dotenv().ok();
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::config::Config;
|
||||
@@ -19,10 +17,7 @@ pub struct ApplyCommand {
|
||||
#[clap(flatten)]
|
||||
pub config_overrides: CliConfigOverrides,
|
||||
}
|
||||
pub async fn run_apply_command(
|
||||
apply_cli: ApplyCommand,
|
||||
cwd: Option<PathBuf>,
|
||||
) -> anyhow::Result<()> {
|
||||
pub async fn run_apply_command(apply_cli: ApplyCommand) -> anyhow::Result<()> {
|
||||
let config = Config::load_with_cli_overrides(
|
||||
apply_cli
|
||||
.config_overrides
|
||||
@@ -34,13 +29,10 @@ pub async fn run_apply_command(
|
||||
init_chatgpt_token_from_auth(&config.codex_home).await?;
|
||||
|
||||
let task_response = get_task(&config, apply_cli.task_id).await?;
|
||||
apply_diff_from_task(task_response, cwd).await
|
||||
apply_diff_from_task(task_response).await
|
||||
}
|
||||
|
||||
pub async fn apply_diff_from_task(
|
||||
task_response: GetTaskResponse,
|
||||
cwd: Option<PathBuf>,
|
||||
) -> anyhow::Result<()> {
|
||||
pub async fn apply_diff_from_task(task_response: GetTaskResponse) -> anyhow::Result<()> {
|
||||
let diff_turn = match task_response.current_diff_task_turn {
|
||||
Some(turn) => turn,
|
||||
None => anyhow::bail!("No diff turn found"),
|
||||
@@ -50,17 +42,13 @@ pub async fn apply_diff_from_task(
|
||||
_ => None,
|
||||
});
|
||||
match output_diff {
|
||||
Some(output_diff) => apply_diff(&output_diff.diff, cwd).await,
|
||||
Some(output_diff) => apply_diff(&output_diff.diff).await,
|
||||
None => anyhow::bail!("No PR output item found"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_diff(diff: &str, cwd: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let mut cmd = tokio::process::Command::new("git");
|
||||
if let Some(cwd) = cwd {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
let toplevel_output = cmd
|
||||
async fn apply_diff(diff: &str) -> anyhow::Result<()> {
|
||||
let toplevel_output = tokio::process::Command::new("git")
|
||||
.args(vec!["rev-parse", "--show-toplevel"])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
@@ -78,7 +78,17 @@ async fn test_apply_command_creates_fibonacci_file() {
|
||||
.await
|
||||
.expect("Failed to load fixture");
|
||||
|
||||
apply_diff_from_task(task_response, Some(repo_path.to_path_buf()))
|
||||
let original_dir = std::env::current_dir().expect("Failed to get current dir");
|
||||
std::env::set_current_dir(repo_path).expect("Failed to change directory");
|
||||
struct DirGuard(std::path::PathBuf);
|
||||
impl Drop for DirGuard {
|
||||
fn drop(&mut self) {
|
||||
let _ = std::env::set_current_dir(&self.0);
|
||||
}
|
||||
}
|
||||
let _guard = DirGuard(original_dir);
|
||||
|
||||
apply_diff_from_task(task_response)
|
||||
.await
|
||||
.expect("Failed to apply diff from task");
|
||||
|
||||
@@ -163,7 +173,7 @@ console.log(fib(10));
|
||||
.await
|
||||
.expect("Failed to load fixture");
|
||||
|
||||
let apply_result = apply_diff_from_task(task_response, Some(repo_path.to_path_buf())).await;
|
||||
let apply_result = apply_diff_from_task(task_response).await;
|
||||
|
||||
assert!(
|
||||
apply_result.is_err(),
|
||||
|
||||
@@ -18,12 +18,12 @@ workspace = true
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
clap_complete = "4"
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
codex-chatgpt = { path = "../chatgpt" }
|
||||
codex-core = { path = "../core" }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-exec = { path = "../exec" }
|
||||
codex-login = { path = "../login" }
|
||||
codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
codex-mcp-server = { path = "../mcp-server" }
|
||||
codex-tui = { path = "../tui" }
|
||||
serde_json = "1"
|
||||
|
||||
@@ -2,7 +2,6 @@ use clap::CommandFactory;
|
||||
use clap::Parser;
|
||||
use clap_complete::Shell;
|
||||
use clap_complete::generate;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_chatgpt::apply_command::ApplyCommand;
|
||||
use codex_chatgpt::apply_command::run_apply_command;
|
||||
use codex_cli::LandlockCommand;
|
||||
@@ -93,7 +92,7 @@ struct LoginCommand {
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move {
|
||||
cli_main(codex_linux_sandbox_exe).await?;
|
||||
Ok(())
|
||||
})
|
||||
@@ -106,8 +105,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
None => {
|
||||
let mut tui_cli = cli.interactive;
|
||||
prepend_config_flags(&mut tui_cli.config_overrides, cli.config_overrides);
|
||||
let usage = codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?;
|
||||
println!("{}", codex_core::protocol::FinalOutput::from(usage));
|
||||
codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?;
|
||||
}
|
||||
Some(Subcommand::Exec(mut exec_cli)) => {
|
||||
prepend_config_flags(&mut exec_cli.config_overrides, cli.config_overrides);
|
||||
@@ -147,7 +145,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
},
|
||||
Some(Subcommand::Apply(mut apply_cli)) => {
|
||||
prepend_config_flags(&mut apply_cli.config_overrides, cli.config_overrides);
|
||||
run_apply_command(apply_cli, None).await?;
|
||||
run_apply_command(apply_cli).await?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::sync::Arc;
|
||||
use clap::Parser;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::protocol::Submission;
|
||||
@@ -36,7 +35,7 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
|
||||
|
||||
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
||||
let ctrl_c = notify_on_sigint();
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
let codex = Arc::new(codex);
|
||||
|
||||
// Task that reads JSON lines from stdin and forwards to Submission Queue
|
||||
|
||||
@@ -64,11 +64,7 @@ impl CliConfigOverrides {
|
||||
// `-c model=o3` without the quotes.
|
||||
let value: Value = match parse_toml_value(value_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
// Strip leading/trailing quotes if present
|
||||
let trimmed = value_str.trim().trim_matches(|c| c == '"' || c == '\'');
|
||||
Value::String(trimmed.to_string())
|
||||
}
|
||||
Err(_) => Value::String(value_str.to_string()),
|
||||
};
|
||||
|
||||
Ok((key.to_string(), value))
|
||||
|
||||
@@ -498,5 +498,14 @@ Options that are specific to the TUI.
|
||||
|
||||
```toml
|
||||
[tui]
|
||||
# More to come here
|
||||
# This will make it so that Codex does not try to process mouse events, which
|
||||
# means your Terminal's native drag-to-text to text selection and copy/paste
|
||||
# should work. The tradeoff is that Codex will not receive any mouse events, so
|
||||
# it will not be possible to use the mouse to scroll conversation history.
|
||||
#
|
||||
# Note that most terminals support holding down a modifier key when using the
|
||||
# mouse to support text selection. For example, even if Codex mouse capture is
|
||||
# enabled (i.e., this is set to `false`), you can still hold down alt while
|
||||
# dragging the mouse to select text.
|
||||
disable_mouse_capture = true # defaults to `false`
|
||||
```
|
||||
|
||||
@@ -22,7 +22,6 @@ env-flags = "0.1.1"
|
||||
eventsource-stream = "0.2.3"
|
||||
fs2 = "0.4.3"
|
||||
futures = "0.3"
|
||||
libc = "0.2.174"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
mime_guess = "2.0"
|
||||
rand = "0.9"
|
||||
@@ -30,8 +29,7 @@ reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
strum_macros = "0.27.2"
|
||||
strum_macros = "0.27.1"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
tokio = { version = "1", features = [
|
||||
@@ -42,14 +40,12 @@ tokio = { version = "1", features = [
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = "0.7.14"
|
||||
toml = "0.9.2"
|
||||
toml = "0.9.1"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter = "0.25.3"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
wildmatch = "2.4.0"
|
||||
whoami = "1.6.0"
|
||||
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
landlock = "0.4.1"
|
||||
@@ -65,7 +61,6 @@ openssl-sys = { version = "*", features = ["vendored"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
core_test_support = { path = "tests/common" }
|
||||
maplit = "1.0.2"
|
||||
predicates = "3"
|
||||
pretty_assertions = "1.4.1"
|
||||
|
||||
@@ -2,18 +2,9 @@
|
||||
|
||||
This crate implements the business logic for Codex. It is designed to be used by the various Codex UIs written in Rust.
|
||||
|
||||
## Dependencies
|
||||
Though for non-Rust UIs, we are also working to define a _protocol_ for talking to Codex. See:
|
||||
|
||||
Note that `codex-core` makes some assumptions about certain helper utilities being available in the environment. Currently, this
|
||||
- [Specification](../docs/protocol_v1.md)
|
||||
- [Rust types](./src/protocol.rs)
|
||||
|
||||
### macOS
|
||||
|
||||
Expects `/usr/bin/sandbox-exec` to be present.
|
||||
|
||||
### Linux
|
||||
|
||||
Expects the binary containing `codex-core` to run the equivalent of `codex debug landlock` when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details.
|
||||
|
||||
### All Platforms
|
||||
|
||||
Expects the binary containing `codex-core` to simulate the virtual `apply_patch` CLI when `arg1` is `--codex-run-as-apply-patch`. See the `codex-arg0` crate for details.
|
||||
You can use the `proto` subcommand using the executable in the [`cli` crate](../cli) to speak the protocol using newline-delimited-JSON over stdin/stdout.
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
use crate::codex::Session;
|
||||
use crate::models::FunctionCallOutputPayload;
|
||||
use crate::models::ResponseInputItem;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::FileChange;
|
||||
use crate::protocol::PatchApplyBeginEvent;
|
||||
use crate::protocol::PatchApplyEndEvent;
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::safety::assess_patch_safety;
|
||||
use anyhow::Context;
|
||||
use codex_apply_patch::AffectedPaths;
|
||||
use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_apply_patch::ApplyPatchFileChange;
|
||||
use codex_apply_patch::print_summary;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub(crate) async fn apply_patch(
|
||||
sess: &Session,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
action: ApplyPatchAction,
|
||||
) -> ResponseInputItem {
|
||||
let writable_roots_snapshot = {
|
||||
#[allow(clippy::unwrap_used)]
|
||||
let guard = sess.writable_roots.lock().unwrap();
|
||||
guard.clone()
|
||||
};
|
||||
|
||||
let auto_approved = match assess_patch_safety(
|
||||
&action,
|
||||
sess.approval_policy,
|
||||
&writable_roots_snapshot,
|
||||
&sess.cwd,
|
||||
) {
|
||||
SafetyCheck::AutoApprove { .. } => true,
|
||||
SafetyCheck::AskUser => {
|
||||
// Compute a readable summary of path changes to include in the
|
||||
// approval request so the user can make an informed decision.
|
||||
let rx_approve = sess
|
||||
.request_patch_approval(sub_id.clone(), call_id.clone(), &action, None, None)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false,
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "patch rejected by user".to_string(),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
SafetyCheck::Reject { reason } => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("patch rejected: {reason}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Verify write permissions before touching the filesystem.
|
||||
let writable_snapshot = {
|
||||
#[allow(clippy::unwrap_used)]
|
||||
sess.writable_roots.lock().unwrap().clone()
|
||||
};
|
||||
|
||||
if let Some(offending) = first_offending_path(&action, &writable_snapshot, &sess.cwd) {
|
||||
let root = offending.parent().unwrap_or(&offending).to_path_buf();
|
||||
|
||||
let reason = Some(format!(
|
||||
"grant write access to {} for this session",
|
||||
root.display()
|
||||
));
|
||||
|
||||
let rx = sess
|
||||
.request_patch_approval(
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
&action,
|
||||
reason.clone(),
|
||||
Some(root.clone()),
|
||||
)
|
||||
.await;
|
||||
|
||||
if !matches!(
|
||||
rx.await.unwrap_or_default(),
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession
|
||||
) {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "patch rejected by user".to_string(),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// user approved, extend writable roots for this session
|
||||
#[allow(clippy::unwrap_used)]
|
||||
sess.writable_roots.lock().unwrap().push(root);
|
||||
}
|
||||
|
||||
let _ = sess
|
||||
.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
auto_approved,
|
||||
changes: convert_apply_patch_to_protocol(&action),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut stdout = Vec::new();
|
||||
let mut stderr = Vec::new();
|
||||
// Enforce writable roots. If a write is blocked, collect offending root
|
||||
// and prompt the user to extend permissions.
|
||||
let mut result = apply_changes_from_apply_patch_and_report(&action, &mut stdout, &mut stderr);
|
||||
|
||||
if let Err(err) = &result {
|
||||
if err.kind() == std::io::ErrorKind::PermissionDenied {
|
||||
// Determine first offending path.
|
||||
let offending_opt = action
|
||||
.changes()
|
||||
.iter()
|
||||
.flat_map(|(path, change)| match change {
|
||||
ApplyPatchFileChange::Add { .. } => vec![path.as_ref()],
|
||||
ApplyPatchFileChange::Delete => vec![path.as_ref()],
|
||||
ApplyPatchFileChange::Update {
|
||||
move_path: Some(move_path),
|
||||
..
|
||||
} => {
|
||||
vec![path.as_ref(), move_path.as_ref()]
|
||||
}
|
||||
ApplyPatchFileChange::Update {
|
||||
move_path: None, ..
|
||||
} => vec![path.as_ref()],
|
||||
})
|
||||
.find_map(|path: &Path| {
|
||||
// ApplyPatchAction promises to guarantee absolute paths.
|
||||
if !path.is_absolute() {
|
||||
panic!("apply_patch invariant failed: path is not absolute: {path:?}");
|
||||
}
|
||||
|
||||
let writable = {
|
||||
#[allow(clippy::unwrap_used)]
|
||||
let roots = sess.writable_roots.lock().unwrap();
|
||||
roots.iter().any(|root| path.starts_with(root))
|
||||
};
|
||||
if writable {
|
||||
None
|
||||
} else {
|
||||
Some(path.to_path_buf())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(offending) = offending_opt {
|
||||
let root = offending.parent().unwrap_or(&offending).to_path_buf();
|
||||
|
||||
let reason = Some(format!(
|
||||
"grant write access to {} for this session",
|
||||
root.display()
|
||||
));
|
||||
let rx = sess
|
||||
.request_patch_approval(
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
&action,
|
||||
reason.clone(),
|
||||
Some(root.clone()),
|
||||
)
|
||||
.await;
|
||||
if matches!(
|
||||
rx.await.unwrap_or_default(),
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession
|
||||
) {
|
||||
// Extend writable roots.
|
||||
#[allow(clippy::unwrap_used)]
|
||||
sess.writable_roots.lock().unwrap().push(root);
|
||||
stdout.clear();
|
||||
stderr.clear();
|
||||
result = apply_changes_from_apply_patch_and_report(
|
||||
&action,
|
||||
&mut stdout,
|
||||
&mut stderr,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit PatchApplyEnd event.
|
||||
let success_flag = result.is_ok();
|
||||
let _ = sess
|
||||
.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id: call_id.clone(),
|
||||
stdout: String::from_utf8_lossy(&stdout).to_string(),
|
||||
stderr: String::from_utf8_lossy(&stderr).to_string(),
|
||||
success: success_flag,
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: String::from_utf8_lossy(&stdout).to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
Err(e) => ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("error: {e:#}, stderr: {}", String::from_utf8_lossy(&stderr)),
|
||||
success: Some(false),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the first path in `hunks` that is NOT under any of the
|
||||
/// `writable_roots` (after normalising). If all paths are acceptable,
|
||||
/// returns None.
|
||||
fn first_offending_path(
|
||||
action: &ApplyPatchAction,
|
||||
writable_roots: &[PathBuf],
|
||||
cwd: &Path,
|
||||
) -> Option<PathBuf> {
|
||||
let changes = action.changes();
|
||||
for (path, change) in changes {
|
||||
let candidate = match change {
|
||||
ApplyPatchFileChange::Add { .. } => path,
|
||||
ApplyPatchFileChange::Delete => path,
|
||||
ApplyPatchFileChange::Update { move_path, .. } => move_path.as_ref().unwrap_or(path),
|
||||
};
|
||||
|
||||
let abs = if candidate.is_absolute() {
|
||||
candidate.clone()
|
||||
} else {
|
||||
cwd.join(candidate)
|
||||
};
|
||||
|
||||
let mut allowed = false;
|
||||
for root in writable_roots {
|
||||
let root_abs = if root.is_absolute() {
|
||||
root.clone()
|
||||
} else {
|
||||
cwd.join(root)
|
||||
};
|
||||
if abs.starts_with(&root_abs) {
|
||||
allowed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return Some(candidate.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn convert_apply_patch_to_protocol(
|
||||
action: &ApplyPatchAction,
|
||||
) -> HashMap<PathBuf, FileChange> {
|
||||
let changes = action.changes();
|
||||
let mut result = HashMap::with_capacity(changes.len());
|
||||
for (path, change) in changes {
|
||||
let protocol_change = match change {
|
||||
ApplyPatchFileChange::Add { content } => FileChange::Add {
|
||||
content: content.clone(),
|
||||
},
|
||||
ApplyPatchFileChange::Delete => FileChange::Delete,
|
||||
ApplyPatchFileChange::Update {
|
||||
unified_diff,
|
||||
move_path,
|
||||
new_content: _new_content,
|
||||
} => FileChange::Update {
|
||||
unified_diff: unified_diff.clone(),
|
||||
move_path: move_path.clone(),
|
||||
},
|
||||
};
|
||||
result.insert(path.clone(), protocol_change);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn apply_changes_from_apply_patch_and_report(
|
||||
action: &ApplyPatchAction,
|
||||
stdout: &mut impl std::io::Write,
|
||||
stderr: &mut impl std::io::Write,
|
||||
) -> std::io::Result<()> {
|
||||
match apply_changes_from_apply_patch(action) {
|
||||
Ok(affected_paths) => {
|
||||
print_summary(&affected_paths, stdout)?;
|
||||
}
|
||||
Err(err) => {
|
||||
writeln!(stderr, "{err:?}")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result<AffectedPaths> {
|
||||
let mut added: Vec<PathBuf> = Vec::new();
|
||||
let mut modified: Vec<PathBuf> = Vec::new();
|
||||
let mut deleted: Vec<PathBuf> = Vec::new();
|
||||
|
||||
let changes = action.changes();
|
||||
for (path, change) in changes {
|
||||
match change {
|
||||
ApplyPatchFileChange::Add { content } => {
|
||||
if let Some(parent) = path.parent() {
|
||||
if !parent.as_os_str().is_empty() {
|
||||
std::fs::create_dir_all(parent).with_context(|| {
|
||||
format!("Failed to create parent directories for {}", path.display())
|
||||
})?;
|
||||
}
|
||||
}
|
||||
std::fs::write(path, content)
|
||||
.with_context(|| format!("Failed to write file {}", path.display()))?;
|
||||
added.push(path.clone());
|
||||
}
|
||||
ApplyPatchFileChange::Delete => {
|
||||
std::fs::remove_file(path)
|
||||
.with_context(|| format!("Failed to delete file {}", path.display()))?;
|
||||
deleted.push(path.clone());
|
||||
}
|
||||
ApplyPatchFileChange::Update {
|
||||
unified_diff: _unified_diff,
|
||||
move_path,
|
||||
new_content,
|
||||
} => {
|
||||
if let Some(move_path) = move_path {
|
||||
if let Some(parent) = move_path.parent() {
|
||||
if !parent.as_os_str().is_empty() {
|
||||
std::fs::create_dir_all(parent).with_context(|| {
|
||||
format!(
|
||||
"Failed to create parent directories for {}",
|
||||
move_path.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
std::fs::rename(path, move_path)
|
||||
.with_context(|| format!("Failed to rename file {}", path.display()))?;
|
||||
std::fs::write(move_path, new_content)?;
|
||||
modified.push(move_path.clone());
|
||||
deleted.push(path.clone());
|
||||
} else {
|
||||
std::fs::write(path, new_content)?;
|
||||
modified.push(path.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AffectedPaths {
|
||||
added,
|
||||
modified,
|
||||
deleted,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_writable_roots(cwd: &Path) -> Vec<PathBuf> {
|
||||
let mut writable_roots = Vec::new();
|
||||
if cfg!(target_os = "macos") {
|
||||
// On macOS, $TMPDIR is private to the user.
|
||||
writable_roots.push(std::env::temp_dir());
|
||||
|
||||
// Allow pyenv to update its shims directory. Without this, any tool
|
||||
// that happens to be managed by `pyenv` will fail with an error like:
|
||||
//
|
||||
// pyenv: cannot rehash: $HOME/.pyenv/shims isn't writable
|
||||
//
|
||||
// which is emitted every time `pyenv` tries to run `rehash` (for
|
||||
// example, after installing a new Python package that drops an entry
|
||||
// point). Although the sandbox is intentionally read‑only by default,
|
||||
// writing to the user's local `pyenv` directory is safe because it
|
||||
// is already user‑writable and scoped to the current user account.
|
||||
if let Ok(home_dir) = std::env::var("HOME") {
|
||||
let pyenv_dir = PathBuf::from(home_dir).join(".pyenv");
|
||||
writable_roots.push(pyenv_dir);
|
||||
}
|
||||
}
|
||||
|
||||
writable_roots.push(cwd.to_path_buf());
|
||||
|
||||
writable_roots
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
use tree_sitter::Parser;
|
||||
use tree_sitter::Tree;
|
||||
use tree_sitter_bash::LANGUAGE as BASH;
|
||||
|
||||
/// Parse the provided bash source using tree-sitter-bash, returning a Tree on
|
||||
/// success or None if parsing failed.
|
||||
pub fn try_parse_bash(bash_lc_arg: &str) -> Option<Tree> {
|
||||
let lang = BASH.into();
|
||||
let mut parser = Parser::new();
|
||||
#[expect(clippy::expect_used)]
|
||||
parser.set_language(&lang).expect("load bash grammar");
|
||||
let old_tree: Option<&Tree> = None;
|
||||
parser.parse(bash_lc_arg, old_tree)
|
||||
}
|
||||
|
||||
/// Parse a script which may contain multiple simple commands joined only by
|
||||
/// the safe logical/pipe/sequencing operators: `&&`, `||`, `;`, `|`.
|
||||
///
|
||||
/// Returns `Some(Vec<command_words>)` if every command is a plain word‑only
|
||||
/// command and the parse tree does not contain disallowed constructs
|
||||
/// (parentheses, redirections, substitutions, control flow, etc.). Otherwise
|
||||
/// returns `None`.
|
||||
pub fn try_parse_word_only_commands_sequence(tree: &Tree, src: &str) -> Option<Vec<Vec<String>>> {
|
||||
if tree.root_node().has_error() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// List of allowed (named) node kinds for a "word only commands sequence".
|
||||
// If we encounter a named node that is not in this list we reject.
|
||||
const ALLOWED_KINDS: &[&str] = &[
|
||||
// top level containers
|
||||
"program",
|
||||
"list",
|
||||
"pipeline",
|
||||
// commands & words
|
||||
"command",
|
||||
"command_name",
|
||||
"word",
|
||||
"string",
|
||||
"string_content",
|
||||
"raw_string",
|
||||
"number",
|
||||
];
|
||||
// Allow only safe punctuation / operator tokens; anything else causes reject.
|
||||
const ALLOWED_PUNCT_TOKENS: &[&str] = &["&&", "||", ";", "|", "\"", "'"];
|
||||
|
||||
let root = tree.root_node();
|
||||
let mut cursor = root.walk();
|
||||
let mut stack = vec![root];
|
||||
let mut command_nodes = Vec::new();
|
||||
while let Some(node) = stack.pop() {
|
||||
let kind = node.kind();
|
||||
if node.is_named() {
|
||||
if !ALLOWED_KINDS.contains(&kind) {
|
||||
return None;
|
||||
}
|
||||
if kind == "command" {
|
||||
command_nodes.push(node);
|
||||
}
|
||||
} else {
|
||||
// Reject any punctuation / operator tokens that are not explicitly allowed.
|
||||
if kind.chars().any(|c| "&;|".contains(c)) && !ALLOWED_PUNCT_TOKENS.contains(&kind) {
|
||||
return None;
|
||||
}
|
||||
if !(ALLOWED_PUNCT_TOKENS.contains(&kind) || kind.trim().is_empty()) {
|
||||
// If it's a quote token or operator it's allowed above; we also allow whitespace tokens.
|
||||
// Any other punctuation like parentheses, braces, redirects, backticks, etc are rejected.
|
||||
return None;
|
||||
}
|
||||
}
|
||||
for child in node.children(&mut cursor) {
|
||||
stack.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
let mut commands = Vec::new();
|
||||
for node in command_nodes {
|
||||
if let Some(words) = parse_plain_command_from_node(node, src) {
|
||||
commands.push(words);
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
Some(commands)
|
||||
}
|
||||
|
||||
fn parse_plain_command_from_node(cmd: tree_sitter::Node, src: &str) -> Option<Vec<String>> {
|
||||
if cmd.kind() != "command" {
|
||||
return None;
|
||||
}
|
||||
let mut words = Vec::new();
|
||||
let mut cursor = cmd.walk();
|
||||
for child in cmd.named_children(&mut cursor) {
|
||||
match child.kind() {
|
||||
"command_name" => {
|
||||
let word_node = child.named_child(0)?;
|
||||
if word_node.kind() != "word" {
|
||||
return None;
|
||||
}
|
||||
words.push(word_node.utf8_text(src.as_bytes()).ok()?.to_owned());
|
||||
}
|
||||
"word" | "number" => {
|
||||
words.push(child.utf8_text(src.as_bytes()).ok()?.to_owned());
|
||||
}
|
||||
"string" => {
|
||||
if child.child_count() == 3
|
||||
&& child.child(0)?.kind() == "\""
|
||||
&& child.child(1)?.kind() == "string_content"
|
||||
&& child.child(2)?.kind() == "\""
|
||||
{
|
||||
words.push(child.child(1)?.utf8_text(src.as_bytes()).ok()?.to_owned());
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
"raw_string" => {
|
||||
let raw_string = child.utf8_text(src.as_bytes()).ok()?;
|
||||
let stripped = raw_string
|
||||
.strip_prefix('\'')
|
||||
.and_then(|s| s.strip_suffix('\''));
|
||||
if let Some(s) = stripped {
|
||||
words.push(s.to_owned());
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
Some(words)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
|
||||
fn parse_seq(src: &str) -> Option<Vec<Vec<String>>> {
|
||||
let tree = try_parse_bash(src)?;
|
||||
try_parse_word_only_commands_sequence(&tree, src)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accepts_single_simple_command() {
|
||||
let cmds = parse_seq("ls -1").unwrap();
|
||||
assert_eq!(cmds, vec![vec!["ls".to_string(), "-1".to_string()]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accepts_multiple_commands_with_allowed_operators() {
|
||||
let src = "ls && pwd; echo 'hi there' | wc -l";
|
||||
let cmds = parse_seq(src).unwrap();
|
||||
let expected: Vec<Vec<String>> = vec![
|
||||
vec!["wc".to_string(), "-l".to_string()],
|
||||
vec!["echo".to_string(), "hi there".to_string()],
|
||||
vec!["pwd".to_string()],
|
||||
vec!["ls".to_string()],
|
||||
];
|
||||
assert_eq!(cmds, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_double_and_single_quoted_strings() {
|
||||
let cmds = parse_seq("echo \"hello world\"").unwrap();
|
||||
assert_eq!(
|
||||
cmds,
|
||||
vec![vec!["echo".to_string(), "hello world".to_string()]]
|
||||
);
|
||||
|
||||
let cmds2 = parse_seq("echo 'hi there'").unwrap();
|
||||
assert_eq!(
|
||||
cmds2,
|
||||
vec![vec!["echo".to_string(), "hi there".to_string()]]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accepts_numbers_as_words() {
|
||||
let cmds = parse_seq("echo 123 456").unwrap();
|
||||
assert_eq!(
|
||||
cmds,
|
||||
vec![vec![
|
||||
"echo".to_string(),
|
||||
"123".to_string(),
|
||||
"456".to_string()
|
||||
]]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_parentheses_and_subshells() {
|
||||
assert!(parse_seq("(ls)").is_none());
|
||||
assert!(parse_seq("ls || (pwd && echo hi)").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_redirections_and_unsupported_operators() {
|
||||
assert!(parse_seq("ls > out.txt").is_none());
|
||||
assert!(parse_seq("echo hi & echo bye").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_command_and_process_substitutions_and_expansions() {
|
||||
assert!(parse_seq("echo $(pwd)").is_none());
|
||||
assert!(parse_seq("echo `pwd`").is_none());
|
||||
assert!(parse_seq("echo $HOME").is_none());
|
||||
assert!(parse_seq("echo \"hi $USER\"").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_variable_assignment_prefix() {
|
||||
assert!(parse_seq("FOO=bar ls").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_trailing_operator_parse_error() {
|
||||
assert!(parse_seq("ls &&").is_none());
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
|
||||
for item in &prompt.input {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
ResponseItem::Message { role, content } => {
|
||||
let mut text = String::new();
|
||||
for c in content {
|
||||
match c {
|
||||
@@ -58,7 +58,6 @@ pub(crate) async fn stream_chat_completions(
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
@@ -260,7 +259,6 @@ async fn process_chat_sse<S>(
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: content.to_string(),
|
||||
}],
|
||||
id: None,
|
||||
};
|
||||
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
@@ -302,7 +300,6 @@ async fn process_chat_sse<S>(
|
||||
"tool_calls" if fn_call_state.active => {
|
||||
// Build the FunctionCall response item.
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
||||
arguments: fn_call_state.arguments.clone(),
|
||||
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
||||
@@ -405,7 +402,6 @@ where
|
||||
}))) => {
|
||||
if !this.cumulative.is_empty() {
|
||||
let aggregated_item = crate::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![crate::models::ContentItem::OutputText {
|
||||
text: std::mem::take(&mut this.cumulative),
|
||||
|
||||
@@ -15,7 +15,6 @@ use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::chat_completions::AggregateStreamExt;
|
||||
use crate::chat_completions::stream_chat_completions;
|
||||
@@ -43,7 +42,6 @@ pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
client: reqwest::Client,
|
||||
provider: ModelProviderInfo,
|
||||
session_id: Uuid,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
}
|
||||
@@ -54,13 +52,11 @@ impl ModelClient {
|
||||
provider: ModelProviderInfo,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_id: Uuid,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
client: reqwest::Client::new(),
|
||||
provider,
|
||||
session_id,
|
||||
effort,
|
||||
summary,
|
||||
}
|
||||
@@ -117,15 +113,6 @@ impl ModelClient {
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||
let tools_json = create_tools_json_for_responses_api(prompt, &self.config.model)?;
|
||||
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
|
||||
|
||||
// Request encrypted COT if we are not storing responses,
|
||||
// otherwise reasoning items will be referenced by ID
|
||||
let include = if !prompt.store && reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let payload = ResponsesApiRequest {
|
||||
model: &self.config.model,
|
||||
instructions: &full_instructions,
|
||||
@@ -134,10 +121,10 @@ impl ModelClient {
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
reasoning,
|
||||
previous_response_id: prompt.prev_id.clone(),
|
||||
store: prompt.store,
|
||||
// TODO: make this configurable
|
||||
stream: true,
|
||||
include,
|
||||
};
|
||||
|
||||
trace!(
|
||||
@@ -155,22 +142,10 @@ impl ModelClient {
|
||||
.provider
|
||||
.create_request_builder(&self.client)?
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("session_id", self.session_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload);
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, request-id: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("x-request-id")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
@@ -394,19 +369,6 @@ async fn process_sse<S>(
|
||||
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
|
||||
}
|
||||
}
|
||||
"response.failed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
let error = resp_val
|
||||
.get("error")
|
||||
.and_then(|v| v.get("message"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("response.failed event received");
|
||||
|
||||
let _ = tx_event
|
||||
.send(Err(CodexErr::Stream(error.to_string())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
// Final response completed – includes array of output items & id
|
||||
"response.completed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
|
||||
@@ -22,6 +22,8 @@ const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
pub struct Prompt {
|
||||
/// Conversation context input items.
|
||||
pub input: Vec<ResponseItem>,
|
||||
/// Optional previous response ID (when storage is enabled).
|
||||
pub prev_id: Option<String>,
|
||||
/// Optional instructions from the user to amend to the built-in agent
|
||||
/// instructions.
|
||||
pub user_instructions: Option<String>,
|
||||
@@ -32,18 +34,11 @@ pub struct Prompt {
|
||||
/// the "fully qualified" tool name (i.e., prefixed with the server name),
|
||||
/// which should be reported to the model in place of Tool::name.
|
||||
pub extra_tools: HashMap<String, mcp_types::Tool>,
|
||||
|
||||
/// Optional override for the built-in BASE_INSTRUCTIONS.
|
||||
pub base_instructions_override: Option<String>,
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
|
||||
let base = self
|
||||
.base_instructions_override
|
||||
.as_deref()
|
||||
.unwrap_or(BASE_INSTRUCTIONS);
|
||||
let mut sections: Vec<&str> = vec![base];
|
||||
let mut sections: Vec<&str> = vec![BASE_INSTRUCTIONS];
|
||||
if let Some(ref user) = self.user_instructions {
|
||||
sections.push(user);
|
||||
}
|
||||
@@ -131,10 +126,11 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
||||
pub(crate) tool_choice: &'static str,
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
pub(crate) reasoning: Option<Reasoning>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) previous_response_id: Option<String>,
|
||||
/// true when using the Responses API.
|
||||
pub(crate) store: bool,
|
||||
pub(crate) stream: bool,
|
||||
pub(crate) include: Vec<String>,
|
||||
}
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,35 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::Codex;
|
||||
use crate::CodexSpawnOk;
|
||||
use crate::config::Config;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::util::notify_on_sigint;
|
||||
use tokio::sync::Notify;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Represents an active Codex conversation, including the first event
|
||||
/// (which is [`EventMsg::SessionConfigured`]).
|
||||
pub struct CodexConversation {
|
||||
pub codex: Codex,
|
||||
pub session_id: Uuid,
|
||||
pub session_configured: Event,
|
||||
pub ctrl_c: Arc<Notify>,
|
||||
}
|
||||
|
||||
/// Spawn a new [`Codex`] and initialize the session.
|
||||
///
|
||||
/// Returns the wrapped [`Codex`] **and** the `SessionInitialized` event that
|
||||
/// is received as a response to the initial `ConfigureSession` submission so
|
||||
/// that callers can surface the information to the UI.
|
||||
pub async fn init_codex(config: Config) -> anyhow::Result<CodexConversation> {
|
||||
pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc<Notify>)> {
|
||||
let ctrl_c = notify_on_sigint();
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
init_id,
|
||||
session_id,
|
||||
} = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
let (codex, init_id) = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
|
||||
// The first event must be `SessionInitialized`. Validate and forward it to
|
||||
// the caller so that they can display it in the conversation history.
|
||||
@@ -48,10 +33,5 @@ pub async fn init_codex(config: Config) -> anyhow::Result<CodexConversation> {
|
||||
));
|
||||
}
|
||||
|
||||
Ok(CodexConversation {
|
||||
codex,
|
||||
session_id,
|
||||
session_configured: event,
|
||||
ctrl_c,
|
||||
})
|
||||
Ok((codex, event, ctrl_c))
|
||||
}
|
||||
|
||||
@@ -63,10 +63,7 @@ pub struct Config {
|
||||
pub disable_response_storage: bool,
|
||||
|
||||
/// User-provided instructions from instructions.md.
|
||||
pub user_instructions: Option<String>,
|
||||
|
||||
/// Base instructions override.
|
||||
pub base_instructions: Option<String>,
|
||||
pub instructions: Option<String>,
|
||||
|
||||
/// Optional external notifier command. When set, Codex will spawn this
|
||||
/// program after each completed *turn* (i.e. when the agent finishes
|
||||
@@ -140,9 +137,6 @@ pub struct Config {
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: String,
|
||||
|
||||
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
|
||||
pub experimental_resume: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -327,12 +321,6 @@ pub struct ConfigToml {
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
|
||||
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
|
||||
pub experimental_resume: Option<PathBuf>,
|
||||
|
||||
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -365,7 +353,6 @@ pub struct ConfigOverrides {
|
||||
pub model_provider: Option<String>,
|
||||
pub config_profile: Option<String>,
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub base_instructions: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -376,7 +363,7 @@ impl Config {
|
||||
overrides: ConfigOverrides,
|
||||
codex_home: PathBuf,
|
||||
) -> std::io::Result<Self> {
|
||||
let user_instructions = Self::load_instructions(Some(&codex_home));
|
||||
let instructions = Self::load_instructions(Some(&codex_home));
|
||||
|
||||
// Destructure ConfigOverrides fully to ensure all overrides are applied.
|
||||
let ConfigOverrides {
|
||||
@@ -387,7 +374,6 @@ impl Config {
|
||||
model_provider,
|
||||
config_profile: config_profile_key,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
} = overrides;
|
||||
|
||||
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
|
||||
@@ -462,13 +448,6 @@ impl Config {
|
||||
.as_ref()
|
||||
.map(|info| info.max_output_tokens)
|
||||
});
|
||||
|
||||
let experimental_resume = cfg.experimental_resume;
|
||||
|
||||
let base_instructions = base_instructions.or(Self::get_base_instructions(
|
||||
cfg.experimental_instructions_file.as_ref(),
|
||||
));
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
model_context_window,
|
||||
@@ -487,8 +466,7 @@ impl Config {
|
||||
.or(cfg.disable_response_storage)
|
||||
.unwrap_or(false),
|
||||
notify: cfg.notify,
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
model_providers,
|
||||
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
|
||||
@@ -516,8 +494,6 @@ impl Config {
|
||||
.chatgpt_base_url
|
||||
.or(cfg.chatgpt_base_url)
|
||||
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
||||
|
||||
experimental_resume,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
@@ -538,15 +514,6 @@ impl Config {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_base_instructions(path: Option<&PathBuf>) -> Option<String> {
|
||||
let path = path.as_ref()?;
|
||||
|
||||
std::fs::read_to_string(path)
|
||||
.ok()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_model() -> String {
|
||||
@@ -561,7 +528,7 @@ fn default_model() -> String {
|
||||
/// function will Err if the path does not exist.
|
||||
/// - If `CODEX_HOME` is not set, this function does not verify that the
|
||||
/// directory exists.
|
||||
pub fn find_codex_home() -> std::io::Result<PathBuf> {
|
||||
fn find_codex_home() -> std::io::Result<PathBuf> {
|
||||
// Honor the `CODEX_HOME` environment variable when it is set to allow users
|
||||
// (and tests) to override the default location.
|
||||
if let Ok(val) = std::env::var("CODEX_HOME") {
|
||||
@@ -823,7 +790,7 @@ disable_response_storage = true
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
@@ -839,8 +806,6 @@ disable_response_storage = true
|
||||
model_reasoning_summary: ReasoningSummary::Detailed,
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
},
|
||||
o3_profile_config
|
||||
);
|
||||
@@ -871,7 +836,7 @@ disable_response_storage = true
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
@@ -887,8 +852,6 @@ disable_response_storage = true
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
|
||||
@@ -934,7 +897,7 @@ disable_response_storage = true
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: true,
|
||||
user_instructions: None,
|
||||
instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
@@ -950,8 +913,6 @@ disable_response_storage = true
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
|
||||
|
||||
@@ -76,7 +76,20 @@ pub enum HistoryPersistence {
|
||||
|
||||
/// Collection of settings that are specific to the TUI.
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Tui {}
|
||||
pub struct Tui {
|
||||
/// By default, mouse capture is enabled in the TUI so that it is possible
|
||||
/// to scroll the conversation history with a mouse. This comes at the cost
|
||||
/// of not being able to use the mouse to select text in the TUI.
|
||||
/// (Most terminals support a modifier key to allow this. For example,
|
||||
/// text selection works in iTerm if you hold down the `Option` key while
|
||||
/// clicking and dragging.)
|
||||
///
|
||||
/// Setting this option to `true` disables mouse capture, so scrolling with
|
||||
/// the mouse is not possible, though the keyboard shortcuts e.g. `b` and
|
||||
/// `space` still work. This allows the user to select text in the TUI
|
||||
/// using the mouse without needing to hold down a modifier key.
|
||||
pub disable_mouse_capture: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
@@ -130,8 +143,6 @@ pub struct ShellEnvironmentPolicyToml {
|
||||
|
||||
/// List of regular expressions.
|
||||
pub include_only: Option<Vec<String>>,
|
||||
|
||||
pub experimental_use_profile: Option<bool>,
|
||||
}
|
||||
|
||||
pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
|
||||
@@ -160,9 +171,6 @@ pub struct ShellEnvironmentPolicy {
|
||||
|
||||
/// Environment variable names to retain in the environment.
|
||||
pub include_only: Vec<EnvironmentVariablePattern>,
|
||||
|
||||
/// If true, the shell profile will be used to run the command.
|
||||
pub use_profile: bool,
|
||||
}
|
||||
|
||||
impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
||||
@@ -182,7 +190,6 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
||||
.into_iter()
|
||||
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
|
||||
.collect();
|
||||
let use_profile = toml.experimental_use_profile.unwrap_or(false);
|
||||
|
||||
Self {
|
||||
inherit,
|
||||
@@ -190,7 +197,6 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
||||
exclude,
|
||||
r#set,
|
||||
include_only,
|
||||
use_profile,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
/// Transcript of conversation history
|
||||
#[derive(Debug, Clone, Default)]
|
||||
/// Transcript of conversation history that is needed:
|
||||
/// - for ZDR clients for which previous_response_id is not available, so we
|
||||
/// must include the transcript with every API call. This must include each
|
||||
/// `function_call` and its corresponding `function_call_output`.
|
||||
/// - for clients using the "chat completions" API as opposed to the
|
||||
/// "responses" API.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ConversationHistory {
|
||||
/// The oldest items are at the beginning of the vector.
|
||||
items: Vec<ResponseItem>,
|
||||
@@ -39,8 +44,7 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
ResponseItem::Message { role, .. } => role.as_str() != "system",
|
||||
ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
| ResponseItem::LocalShellCall { .. } => true,
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use tokio::io::BufReader;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Notify;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
@@ -83,8 +82,7 @@ pub async fn process_exec_tool_call(
|
||||
) -> Result<ExecToolCallOutput> {
|
||||
let start = Instant::now();
|
||||
|
||||
let raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr> = match sandbox_type
|
||||
{
|
||||
let raw_output_result = match sandbox_type {
|
||||
SandboxType::None => exec(params, sandbox_policy, ctrl_c).await,
|
||||
SandboxType::MacosSeatbelt => {
|
||||
let ExecParams {
|
||||
@@ -374,10 +372,6 @@ async fn spawn_child_async(
|
||||
stdio_policy: StdioPolicy,
|
||||
env: HashMap<String, String>,
|
||||
) -> std::io::Result<Child> {
|
||||
trace!(
|
||||
"spawn_child_async: {program:?} {args:?} {arg0:?} {cwd:?} {sandbox_policy:?} {stdio_policy:?} {env:?}"
|
||||
);
|
||||
|
||||
let mut cmd = Command::new(&program);
|
||||
#[cfg(unix)]
|
||||
cmd.arg0(arg0.map_or_else(|| program.to_string_lossy().to_string(), String::from));
|
||||
@@ -390,31 +384,6 @@ async fn spawn_child_async(
|
||||
cmd.env(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR, "1");
|
||||
}
|
||||
|
||||
// If this Codex process dies (including being killed via SIGKILL), we want
|
||||
// any child processes that were spawned as part of a `"shell"` tool call
|
||||
// to also be terminated.
|
||||
|
||||
// This relies on prctl(2), so it only works on Linux.
|
||||
#[cfg(target_os = "linux")]
|
||||
unsafe {
|
||||
cmd.pre_exec(|| {
|
||||
// This prctl call effectively requests, "deliver SIGTERM when my
|
||||
// current parent dies."
|
||||
if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
// Though if there was a race condition and this pre_exec() block is
|
||||
// run _after_ the parent (i.e., the Codex process) has already
|
||||
// exited, then the parent is the _init_ process (which will never
|
||||
// die), so we should just terminate the child process now.
|
||||
if libc::getppid() == 1 {
|
||||
libc::raise(libc::SIGTERM);
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
match stdio_policy {
|
||||
StdioPolicy::RedirectForShellTool => {
|
||||
// Do not create a file descriptor for stdin because otherwise some
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::Duration as TokioDuration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
/// Timeout for git commands to prevent freezing on large repositories
|
||||
const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5);
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct GitInfo {
|
||||
/// Current commit hash (SHA)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub commit_hash: Option<String>,
|
||||
/// Current branch name
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub branch: Option<String>,
|
||||
/// Repository URL (if available from remote)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repository_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Collect git repository information from the given working directory using command-line git.
|
||||
/// Returns None if no git repository is found or if git operations fail.
|
||||
/// Uses timeouts to prevent freezing on large repositories.
|
||||
/// All git commands (except the initial repo check) run in parallel for better performance.
|
||||
pub async fn collect_git_info(cwd: &Path) -> Option<GitInfo> {
|
||||
// Check if we're in a git repository first
|
||||
let is_git_repo = run_git_command_with_timeout(&["rev-parse", "--git-dir"], cwd)
|
||||
.await?
|
||||
.status
|
||||
.success();
|
||||
|
||||
if !is_git_repo {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Run all git info collection commands in parallel
|
||||
let (commit_result, branch_result, url_result) = tokio::join!(
|
||||
run_git_command_with_timeout(&["rev-parse", "HEAD"], cwd),
|
||||
run_git_command_with_timeout(&["rev-parse", "--abbrev-ref", "HEAD"], cwd),
|
||||
run_git_command_with_timeout(&["remote", "get-url", "origin"], cwd)
|
||||
);
|
||||
|
||||
let mut git_info = GitInfo {
|
||||
commit_hash: None,
|
||||
branch: None,
|
||||
repository_url: None,
|
||||
};
|
||||
|
||||
// Process commit hash
|
||||
if let Some(output) = commit_result {
|
||||
if output.status.success() {
|
||||
if let Ok(hash) = String::from_utf8(output.stdout) {
|
||||
git_info.commit_hash = Some(hash.trim().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process branch name
|
||||
if let Some(output) = branch_result {
|
||||
if output.status.success() {
|
||||
if let Ok(branch) = String::from_utf8(output.stdout) {
|
||||
let branch = branch.trim();
|
||||
if branch != "HEAD" {
|
||||
git_info.branch = Some(branch.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process repository URL
|
||||
if let Some(output) = url_result {
|
||||
if output.status.success() {
|
||||
if let Ok(url) = String::from_utf8(output.stdout) {
|
||||
git_info.repository_url = Some(url.trim().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(git_info)
|
||||
}
|
||||
|
||||
/// Run a git command with a timeout to prevent blocking on large repositories
|
||||
async fn run_git_command_with_timeout(args: &[&str], cwd: &Path) -> Option<std::process::Output> {
|
||||
let result = timeout(
|
||||
GIT_COMMAND_TIMEOUT,
|
||||
Command::new("git").args(args).current_dir(cwd).output(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => Some(output),
|
||||
_ => None, // Timeout or error
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used)]
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
use super::*;
|
||||
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// Helper function to create a test git repository
|
||||
async fn create_test_git_repo(temp_dir: &TempDir) -> PathBuf {
|
||||
let repo_path = temp_dir.path().to_path_buf();
|
||||
|
||||
// Initialize git repo
|
||||
Command::new("git")
|
||||
.args(["init"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to init git repo");
|
||||
|
||||
// Configure git user (required for commits)
|
||||
Command::new("git")
|
||||
.args(["config", "user.name", "Test User"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to set git user name");
|
||||
|
||||
Command::new("git")
|
||||
.args(["config", "user.email", "test@example.com"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to set git user email");
|
||||
|
||||
// Create a test file and commit it
|
||||
let test_file = repo_path.join("test.txt");
|
||||
fs::write(&test_file, "test content").expect("Failed to write test file");
|
||||
|
||||
Command::new("git")
|
||||
.args(["add", "."])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to add files");
|
||||
|
||||
Command::new("git")
|
||||
.args(["commit", "-m", "Initial commit"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to commit");
|
||||
|
||||
repo_path
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_collect_git_info_non_git_directory() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let result = collect_git_info(temp_dir.path()).await;
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_collect_git_info_git_repository() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
|
||||
let git_info = collect_git_info(&repo_path)
|
||||
.await
|
||||
.expect("Should collect git info from repo");
|
||||
|
||||
// Should have commit hash
|
||||
assert!(git_info.commit_hash.is_some());
|
||||
let commit_hash = git_info.commit_hash.unwrap();
|
||||
assert_eq!(commit_hash.len(), 40); // SHA-1 hash should be 40 characters
|
||||
assert!(commit_hash.chars().all(|c| c.is_ascii_hexdigit()));
|
||||
|
||||
// Should have branch (likely "main" or "master")
|
||||
assert!(git_info.branch.is_some());
|
||||
let branch = git_info.branch.unwrap();
|
||||
assert!(branch == "main" || branch == "master");
|
||||
|
||||
// Repository URL might be None for local repos without remote
|
||||
// This is acceptable behavior
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_collect_git_info_with_remote() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
|
||||
// Add a remote origin
|
||||
Command::new("git")
|
||||
.args([
|
||||
"remote",
|
||||
"add",
|
||||
"origin",
|
||||
"https://github.com/example/repo.git",
|
||||
])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to add remote");
|
||||
|
||||
let git_info = collect_git_info(&repo_path)
|
||||
.await
|
||||
.expect("Should collect git info from repo");
|
||||
|
||||
// Should have repository URL
|
||||
assert_eq!(
|
||||
git_info.repository_url,
|
||||
Some("https://github.com/example/repo.git".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_collect_git_info_detached_head() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
|
||||
// Get the current commit hash
|
||||
let output = Command::new("git")
|
||||
.args(["rev-parse", "HEAD"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to get HEAD");
|
||||
let commit_hash = String::from_utf8(output.stdout).unwrap().trim().to_string();
|
||||
|
||||
// Checkout the commit directly (detached HEAD)
|
||||
Command::new("git")
|
||||
.args(["checkout", &commit_hash])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to checkout commit");
|
||||
|
||||
let git_info = collect_git_info(&repo_path)
|
||||
.await
|
||||
.expect("Should collect git info from repo");
|
||||
|
||||
// Should have commit hash
|
||||
assert!(git_info.commit_hash.is_some());
|
||||
// Branch should be None for detached HEAD (since rev-parse --abbrev-ref HEAD returns "HEAD")
|
||||
assert!(git_info.branch.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_collect_git_info_with_branch() {
|
||||
let temp_dir = TempDir::new().expect("Failed to create temp dir");
|
||||
let repo_path = create_test_git_repo(&temp_dir).await;
|
||||
|
||||
// Create and checkout a new branch
|
||||
Command::new("git")
|
||||
.args(["checkout", "-b", "feature-branch"])
|
||||
.current_dir(&repo_path)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to create branch");
|
||||
|
||||
let git_info = collect_git_info(&repo_path)
|
||||
.await
|
||||
.expect("Should collect git info from repo");
|
||||
|
||||
// Should have the new branch name
|
||||
assert_eq!(git_info.branch, Some("feature-branch".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_info_serialization() {
|
||||
let git_info = GitInfo {
|
||||
commit_hash: Some("abc123def456".to_string()),
|
||||
branch: Some("main".to_string()),
|
||||
repository_url: Some("https://github.com/example/repo.git".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo");
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON");
|
||||
|
||||
assert_eq!(parsed["commit_hash"], "abc123def456");
|
||||
assert_eq!(parsed["branch"], "main");
|
||||
assert_eq!(
|
||||
parsed["repository_url"],
|
||||
"https://github.com/example/repo.git"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_info_serialization_with_nones() {
|
||||
let git_info = GitInfo {
|
||||
commit_hash: None,
|
||||
branch: None,
|
||||
repository_url: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo");
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON");
|
||||
|
||||
// Fields with None values should be omitted due to skip_serializing_if
|
||||
assert!(!parsed.as_object().unwrap().contains_key("commit_hash"));
|
||||
assert!(!parsed.as_object().unwrap().contains_key("branch"));
|
||||
assert!(!parsed.as_object().unwrap().contains_key("repository_url"));
|
||||
}
|
||||
}
|
||||
@@ -1,57 +1,31 @@
|
||||
use crate::bash::try_parse_bash;
|
||||
use crate::bash::try_parse_word_only_commands_sequence;
|
||||
use tree_sitter::Parser;
|
||||
use tree_sitter::Tree;
|
||||
use tree_sitter_bash::LANGUAGE as BASH;
|
||||
|
||||
pub fn is_known_safe_command(command: &[String]) -> bool {
|
||||
if is_safe_to_call_with_exec(command) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Support `bash -lc "..."` where the script consists solely of one or
|
||||
// more "plain" commands (only bare words / quoted strings) combined with
|
||||
// a conservative allow‑list of shell operators that themselves do not
|
||||
// introduce side effects ( "&&", "||", ";", and "|" ). If every
|
||||
// individual command in the script is itself a known‑safe command, then
|
||||
// the composite expression is considered safe.
|
||||
if let [bash, flag, script] = command {
|
||||
if bash == "bash" && flag == "-lc" {
|
||||
if let Some(tree) = try_parse_bash(script) {
|
||||
if let Some(all_commands) = try_parse_word_only_commands_sequence(&tree, script) {
|
||||
if !all_commands.is_empty()
|
||||
&& all_commands
|
||||
.iter()
|
||||
.all(|cmd| is_safe_to_call_with_exec(cmd))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
// TODO(mbolin): Also support safe commands that are piped together such
|
||||
// as `cat foo | wc -l`.
|
||||
matches!(
|
||||
command,
|
||||
[bash, flag, script]
|
||||
if bash == "bash"
|
||||
&& flag == "-lc"
|
||||
&& try_parse_bash(script).and_then(|tree|
|
||||
try_parse_single_word_only_command(&tree, script)).is_some_and(|parsed_bash_command| is_safe_to_call_with_exec(&parsed_bash_command))
|
||||
)
|
||||
}
|
||||
|
||||
fn is_safe_to_call_with_exec(command: &[String]) -> bool {
|
||||
let cmd0 = command.first().map(String::as_str);
|
||||
|
||||
match cmd0 {
|
||||
#[rustfmt::skip]
|
||||
Some(
|
||||
"cat" |
|
||||
"cd" |
|
||||
"echo" |
|
||||
"false" |
|
||||
"grep" |
|
||||
"head" |
|
||||
"ls" |
|
||||
"nl" |
|
||||
"pwd" |
|
||||
"tail" |
|
||||
"true" |
|
||||
"wc" |
|
||||
"which") => {
|
||||
true
|
||||
},
|
||||
"cat" | "cd" | "echo" | "grep" | "head" | "ls" | "pwd" | "rg" | "tail" | "wc" | "which",
|
||||
) => true,
|
||||
|
||||
Some("find") => {
|
||||
// Certain options to `find` can delete files, write to files, or
|
||||
@@ -72,29 +46,6 @@ fn is_safe_to_call_with_exec(command: &[String]) -> bool {
|
||||
.any(|arg| UNSAFE_FIND_OPTIONS.contains(&arg.as_str()))
|
||||
}
|
||||
|
||||
// Ripgrep
|
||||
Some("rg") => {
|
||||
const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &[
|
||||
// Takes an arbitrary command that is executed for each match.
|
||||
"--pre",
|
||||
// Takes a command that can be used to obtain the local hostname.
|
||||
"--hostname-bin",
|
||||
];
|
||||
const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &[
|
||||
// Calls out to other decompression tools, so do not auto-approve
|
||||
// out of an abundance of caution.
|
||||
"--search-zip",
|
||||
"-z",
|
||||
];
|
||||
|
||||
!command.iter().any(|arg| {
|
||||
UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg.as_str())
|
||||
|| UNSAFE_RIPGREP_OPTIONS_WITH_ARGS
|
||||
.iter()
|
||||
.any(|&opt| arg == opt || arg.starts_with(&format!("{opt}=")))
|
||||
})
|
||||
}
|
||||
|
||||
// Git
|
||||
Some("git") => matches!(
|
||||
command.get(1).map(String::as_str),
|
||||
@@ -121,7 +72,90 @@ fn is_safe_to_call_with_exec(command: &[String]) -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
// (bash parsing helpers implemented in crate::bash)
|
||||
fn try_parse_bash(bash_lc_arg: &str) -> Option<Tree> {
|
||||
let lang = BASH.into();
|
||||
let mut parser = Parser::new();
|
||||
#[expect(clippy::expect_used)]
|
||||
parser.set_language(&lang).expect("load bash grammar");
|
||||
|
||||
let old_tree: Option<&Tree> = None;
|
||||
parser.parse(bash_lc_arg, old_tree)
|
||||
}
|
||||
|
||||
/// If `tree` represents a single Bash command whose name and every argument is
|
||||
/// an ordinary `word`, return those words in order; otherwise, return `None`.
|
||||
///
|
||||
/// `src` must be the exact source string that was parsed into `tree`, so we can
|
||||
/// extract the text for every node.
|
||||
pub fn try_parse_single_word_only_command(tree: &Tree, src: &str) -> Option<Vec<String>> {
|
||||
// Any parse error is an immediate rejection.
|
||||
if tree.root_node().has_error() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// (program …) with exactly one statement
|
||||
let root = tree.root_node();
|
||||
if root.kind() != "program" || root.named_child_count() != 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cmd = root.named_child(0)?; // (command …)
|
||||
if cmd.kind() != "command" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut words = Vec::new();
|
||||
let mut cursor = cmd.walk();
|
||||
|
||||
for child in cmd.named_children(&mut cursor) {
|
||||
match child.kind() {
|
||||
// The command name node wraps one `word` child.
|
||||
"command_name" => {
|
||||
let word_node = child.named_child(0)?; // make sure it's only a word
|
||||
if word_node.kind() != "word" {
|
||||
return None;
|
||||
}
|
||||
words.push(word_node.utf8_text(src.as_bytes()).ok()?.to_owned());
|
||||
}
|
||||
// Positional‑argument word (allowed).
|
||||
"word" | "number" => {
|
||||
words.push(child.utf8_text(src.as_bytes()).ok()?.to_owned());
|
||||
}
|
||||
"string" => {
|
||||
if child.child_count() == 3
|
||||
&& child.child(0)?.kind() == "\""
|
||||
&& child.child(1)?.kind() == "string_content"
|
||||
&& child.child(2)?.kind() == "\""
|
||||
{
|
||||
words.push(child.child(1)?.utf8_text(src.as_bytes()).ok()?.to_owned());
|
||||
} else {
|
||||
// Anything else means the command is *not* plain words.
|
||||
return None;
|
||||
}
|
||||
}
|
||||
"concatenation" => {
|
||||
// TODO: Consider things like `'ab\'a'`.
|
||||
return None;
|
||||
}
|
||||
"raw_string" => {
|
||||
// Raw string is a single word, but we need to strip the quotes.
|
||||
let raw_string = child.utf8_text(src.as_bytes()).ok()?;
|
||||
let stripped = raw_string
|
||||
.strip_prefix('\'')
|
||||
.and_then(|s| s.strip_suffix('\''));
|
||||
if let Some(stripped) = stripped {
|
||||
words.push(stripped.to_owned());
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
// Anything else means the command is *not* plain words.
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
Some(words)
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------
|
||||
Example
|
||||
@@ -159,7 +193,6 @@ fn is_valid_sed_n_arg(arg: Option<&str>) -> bool {
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
@@ -176,11 +209,6 @@ mod tests {
|
||||
assert!(is_safe_to_call_with_exec(&vec_str(&[
|
||||
"sed", "-n", "1,5p", "file.txt"
|
||||
])));
|
||||
assert!(is_safe_to_call_with_exec(&vec_str(&[
|
||||
"nl",
|
||||
"-nrz",
|
||||
"Cargo.toml"
|
||||
])));
|
||||
|
||||
// Safe `find` command (no unsafe options).
|
||||
assert!(is_safe_to_call_with_exec(&vec_str(&[
|
||||
@@ -217,40 +245,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ripgrep_rules() {
|
||||
// Safe ripgrep invocations – none of the unsafe flags are present.
|
||||
assert!(is_safe_to_call_with_exec(&vec_str(&[
|
||||
"rg",
|
||||
"Cargo.toml",
|
||||
"-n"
|
||||
])));
|
||||
|
||||
// Unsafe flags that do not take an argument (present verbatim).
|
||||
for args in [
|
||||
vec_str(&["rg", "--search-zip", "files"]),
|
||||
vec_str(&["rg", "-z", "files"]),
|
||||
] {
|
||||
assert!(
|
||||
!is_safe_to_call_with_exec(&args),
|
||||
"expected {args:?} to be considered unsafe due to zip-search flag",
|
||||
);
|
||||
}
|
||||
|
||||
// Unsafe flags that expect a value, provided in both split and = forms.
|
||||
for args in [
|
||||
vec_str(&["rg", "--pre", "pwned", "files"]),
|
||||
vec_str(&["rg", "--pre=pwned", "files"]),
|
||||
vec_str(&["rg", "--hostname-bin", "pwned", "files"]),
|
||||
vec_str(&["rg", "--hostname-bin=pwned", "files"]),
|
||||
] {
|
||||
assert!(
|
||||
!is_safe_to_call_with_exec(&args),
|
||||
"expected {args:?} to be considered unsafe due to external-command flag",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_lc_safe_examples() {
|
||||
assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls"])));
|
||||
@@ -283,30 +277,6 @@ mod tests {
|
||||
])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_lc_safe_examples_with_operators() {
|
||||
assert!(is_known_safe_command(&vec_str(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
"grep -R \"Cargo.toml\" -n || true"
|
||||
])));
|
||||
assert!(is_known_safe_command(&vec_str(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
"ls && pwd"
|
||||
])));
|
||||
assert!(is_known_safe_command(&vec_str(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
"echo 'hi' ; ls"
|
||||
])));
|
||||
assert!(is_known_safe_command(&vec_str(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
"ls | wc -l"
|
||||
])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_lc_unsafe_examples() {
|
||||
assert!(
|
||||
@@ -320,29 +290,44 @@ mod tests {
|
||||
|
||||
assert!(
|
||||
!is_known_safe_command(&vec_str(&["bash", "-lc", "find . -name file.txt -delete"])),
|
||||
"Unsafe find option should not be auto-approved."
|
||||
);
|
||||
|
||||
// Disallowed because of unsafe command in sequence.
|
||||
assert!(
|
||||
!is_known_safe_command(&vec_str(&["bash", "-lc", "ls && rm -rf /"])),
|
||||
"Sequence containing unsafe command must be rejected"
|
||||
);
|
||||
|
||||
// Disallowed because of parentheses / subshell.
|
||||
assert!(
|
||||
!is_known_safe_command(&vec_str(&["bash", "-lc", "(ls)"])),
|
||||
"Parentheses (subshell) are not provably safe with the current parser"
|
||||
);
|
||||
assert!(
|
||||
!is_known_safe_command(&vec_str(&["bash", "-lc", "ls || (pwd && echo hi)"])),
|
||||
"Nested parentheses are not provably safe with the current parser"
|
||||
);
|
||||
|
||||
// Disallowed redirection.
|
||||
assert!(
|
||||
!is_known_safe_command(&vec_str(&["bash", "-lc", "ls > out.txt"])),
|
||||
"> redirection should be rejected"
|
||||
"Unsafe find option should not be auto‑approved."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_single_word_only_command() {
|
||||
let script_with_single_quoted_string = "sed -n '1,5p' file.txt";
|
||||
let parsed_words = try_parse_bash(script_with_single_quoted_string)
|
||||
.and_then(|tree| {
|
||||
try_parse_single_word_only_command(&tree, script_with_single_quoted_string)
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
vec![
|
||||
"sed".to_string(),
|
||||
"-n".to_string(),
|
||||
// Ensure the single quotes are properly removed.
|
||||
"1,5p".to_string(),
|
||||
"file.txt".to_string()
|
||||
],
|
||||
parsed_words,
|
||||
);
|
||||
|
||||
let script_with_number_arg = "ls -1";
|
||||
let parsed_words = try_parse_bash(script_with_number_arg)
|
||||
.and_then(|tree| try_parse_single_word_only_command(&tree, script_with_number_arg))
|
||||
.unwrap();
|
||||
assert_eq!(vec!["ls", "-1"], parsed_words,);
|
||||
|
||||
let script_with_double_quoted_string_with_no_funny_stuff_arg = "grep -R \"Cargo.toml\" -n";
|
||||
let parsed_words = try_parse_bash(script_with_double_quoted_string_with_no_funny_stuff_arg)
|
||||
.and_then(|tree| {
|
||||
try_parse_single_word_only_command(
|
||||
&tree,
|
||||
script_with_double_quoted_string_with_no_funny_stuff_arg,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(vec!["grep", "-R", "Cargo.toml", "-n"], parsed_words);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,14 +5,11 @@
|
||||
// the TUI or the tracing stack).
|
||||
#![deny(clippy::print_stdout, clippy::print_stderr)]
|
||||
|
||||
mod apply_patch;
|
||||
mod bash;
|
||||
mod chat_completions;
|
||||
mod client;
|
||||
mod client_common;
|
||||
pub mod codex;
|
||||
pub use codex::Codex;
|
||||
pub use codex::CodexSpawnOk;
|
||||
pub mod codex_wrapper;
|
||||
pub mod config;
|
||||
pub mod config_profile;
|
||||
@@ -22,7 +19,6 @@ pub mod error;
|
||||
pub mod exec;
|
||||
pub mod exec_env;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
mod is_safe_command;
|
||||
mod mcp_connection_manager;
|
||||
mod mcp_tool_call;
|
||||
@@ -38,7 +34,6 @@ mod project_doc;
|
||||
pub mod protocol;
|
||||
mod rollout;
|
||||
mod safety;
|
||||
pub mod shell;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::OsString;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -19,7 +18,6 @@ use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::Tool;
|
||||
|
||||
use serde_json::json;
|
||||
use sha1::Digest;
|
||||
use sha1::Sha1;
|
||||
use tokio::task::JoinSet;
|
||||
@@ -128,12 +126,7 @@ impl McpConnectionManager {
|
||||
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig { command, args, env } = cfg;
|
||||
let client_res = McpClient::new_stdio_client(
|
||||
command.into(),
|
||||
args.into_iter().map(OsString::from).collect(),
|
||||
env,
|
||||
)
|
||||
.await;
|
||||
let client_res = McpClient::new_stdio_client(command, args, env).await;
|
||||
match client_res {
|
||||
Ok(client) => {
|
||||
// Initialize the client.
|
||||
@@ -142,14 +135,10 @@ impl McpConnectionManager {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
|
||||
// indicates this should be an empty object.
|
||||
elicitation: Some(json!({})),
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
@@ -299,8 +288,6 @@ mod tests {
|
||||
r#type: "object".to_string(),
|
||||
},
|
||||
name: tool_name.to_string(),
|
||||
output_schema: None,
|
||||
title: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::collections::HashMap;
|
||||
use base64::Engine;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serialize;
|
||||
use serde::ser::Serializer;
|
||||
|
||||
@@ -38,14 +37,12 @@ pub enum ContentItem {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseItem {
|
||||
Message {
|
||||
id: Option<String>,
|
||||
role: String,
|
||||
content: Vec<ContentItem>,
|
||||
},
|
||||
Reasoning {
|
||||
id: String,
|
||||
summary: Vec<ReasoningItemReasoningSummary>,
|
||||
encrypted_content: Option<String>,
|
||||
},
|
||||
LocalShellCall {
|
||||
/// Set when using the chat completions API.
|
||||
@@ -56,7 +53,6 @@ pub enum ResponseItem {
|
||||
action: LocalShellAction,
|
||||
},
|
||||
FunctionCall {
|
||||
id: Option<String>,
|
||||
name: String,
|
||||
// The Responses API returns the function call arguments as a *string* that contains
|
||||
// JSON, not as an already‑parsed object. We keep it as a raw string here and let
|
||||
@@ -82,11 +78,7 @@ pub enum ResponseItem {
|
||||
impl From<ResponseInputItem> for ResponseItem {
|
||||
fn from(item: ResponseInputItem) -> Self {
|
||||
match item {
|
||||
ResponseInputItem::Message { role, content } => Self::Message {
|
||||
role,
|
||||
content,
|
||||
id: None,
|
||||
},
|
||||
ResponseInputItem::Message { role, content } => Self::Message { role, content },
|
||||
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
||||
Self::FunctionCallOutput { call_id, output }
|
||||
}
|
||||
@@ -185,7 +177,7 @@ pub struct ShellToolCallParams {
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct FunctionCallOutputPayload {
|
||||
pub content: String,
|
||||
#[expect(dead_code)]
|
||||
@@ -213,19 +205,6 @@ impl Serialize for FunctionCallOutputPayload {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for FunctionCallOutputPayload {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
Ok(FunctionCallOutputPayload {
|
||||
content: s,
|
||||
success: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Display so callers can treat the payload like a plain string when logging or doing
|
||||
// trivial substring checks in tests (existing tests call `.contains()` on the output). Display
|
||||
// returns the raw `content` field.
|
||||
|
||||
@@ -27,16 +27,16 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
|
||||
/// string of instructions.
|
||||
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
|
||||
match find_project_doc(config).await {
|
||||
Ok(Some(project_doc)) => match &config.user_instructions {
|
||||
Ok(Some(project_doc)) => match &config.instructions {
|
||||
Some(original_instructions) => Some(format!(
|
||||
"{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}"
|
||||
)),
|
||||
None => Some(project_doc),
|
||||
},
|
||||
Ok(None) => config.user_instructions.clone(),
|
||||
Ok(None) => config.instructions.clone(),
|
||||
Err(e) => {
|
||||
error!("error trying to find project doc: {e:#}");
|
||||
config.user_instructions.clone()
|
||||
config.instructions.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,7 +159,7 @@ mod tests {
|
||||
config.cwd = root.path().to_path_buf();
|
||||
config.project_doc_max_bytes = limit;
|
||||
|
||||
config.user_instructions = instructions.map(ToOwned::to_owned);
|
||||
config.instructions = instructions.map(ToOwned::to_owned);
|
||||
config
|
||||
}
|
||||
|
||||
|
||||
@@ -4,15 +4,13 @@
|
||||
//! between user and agent.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr; // Added for FinalOutput Display implementation
|
||||
use std::str::FromStr;
|
||||
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use strum_macros::Display;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
@@ -46,12 +44,8 @@ pub enum Op {
|
||||
model_reasoning_effort: ReasoningEffortConfig,
|
||||
model_reasoning_summary: ReasoningSummaryConfig,
|
||||
|
||||
/// Model instructions that are appended to the base instructions.
|
||||
user_instructions: Option<String>,
|
||||
|
||||
/// Base instructions override.
|
||||
base_instructions: Option<String>,
|
||||
|
||||
/// Model instructions
|
||||
instructions: Option<String>,
|
||||
/// When to escalate for approval for execution
|
||||
approval_policy: AskForApproval,
|
||||
/// How to sandbox commands executed in the system
|
||||
@@ -75,10 +69,6 @@ pub enum Op {
|
||||
/// `ConfigureSession` operation so that the business-logic layer can
|
||||
/// operate deterministically.
|
||||
cwd: std::path::PathBuf,
|
||||
|
||||
/// Path to a rollout file to resume from.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
resume_path: Option<std::path::PathBuf>,
|
||||
},
|
||||
|
||||
/// Abort current task.
|
||||
@@ -118,23 +108,18 @@ pub enum Op {
|
||||
|
||||
/// Request a single history entry identified by `log_id` + `offset`.
|
||||
GetHistoryEntryRequest { offset: usize, log_id: u64 },
|
||||
|
||||
/// Request to shut down codex instance.
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// Determines the conditions under which the user is consulted to approve
|
||||
/// running the command proposed by Codex.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize, Display)]
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
pub enum AskForApproval {
|
||||
/// Under this policy, only "known safe" commands—as determined by
|
||||
/// `is_safe_command()`—that **only read files** are auto‑approved.
|
||||
/// Everything else will ask the user to approve.
|
||||
#[default]
|
||||
#[serde(rename = "untrusted")]
|
||||
#[strum(serialize = "untrusted")]
|
||||
UnlessTrusted,
|
||||
|
||||
/// *All* commands are auto‑approved, but they are expected to run inside a
|
||||
@@ -278,9 +263,8 @@ pub struct Event {
|
||||
}
|
||||
|
||||
/// Response event from the agent
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Display)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum EventMsg {
|
||||
/// Error while executing a submission
|
||||
Error(ErrorEvent),
|
||||
@@ -334,9 +318,6 @@ pub enum EventMsg {
|
||||
|
||||
/// Response to GetHistoryEntryRequest.
|
||||
GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
|
||||
|
||||
/// Notification that the agent is shutting down.
|
||||
ShutdownComplete,
|
||||
}
|
||||
|
||||
// Individual event payload types matching each `EventMsg` variant.
|
||||
@@ -360,36 +341,6 @@ pub struct TokenUsage {
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FinalOutput {
|
||||
pub token_usage: TokenUsage,
|
||||
}
|
||||
|
||||
impl From<TokenUsage> for FinalOutput {
|
||||
fn from(token_usage: TokenUsage) -> Self {
|
||||
Self { token_usage }
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FinalOutput {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let u = &self.token_usage;
|
||||
write!(
|
||||
f,
|
||||
"Token usage: total={} input={}{} output={}{}",
|
||||
u.total_tokens,
|
||||
u.input_tokens,
|
||||
u.cached_input_tokens
|
||||
.map(|c| format!(" (cached {c})"))
|
||||
.unwrap_or_default(),
|
||||
u.output_tokens,
|
||||
u.reasoning_output_tokens
|
||||
.map(|r| format!(" (reasoning {r})"))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentMessageEvent {
|
||||
pub message: String,
|
||||
@@ -463,8 +414,6 @@ pub struct ExecCommandEndEvent {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ExecApprovalRequestEvent {
|
||||
/// Identifier for the associated exec call, if available.
|
||||
pub call_id: String,
|
||||
/// The command to be executed.
|
||||
pub command: Vec<String>,
|
||||
/// The command's working directory.
|
||||
@@ -476,8 +425,6 @@ pub struct ExecApprovalRequestEvent {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ApplyPatchApprovalRequestEvent {
|
||||
/// Responses API call id for the associated patch apply call, if available.
|
||||
pub call_id: String,
|
||||
pub changes: HashMap<PathBuf, FileChange>,
|
||||
/// Optional explanatory reason (e.g. request for extra write access).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
||||
@@ -1,57 +1,33 @@
|
||||
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
|
||||
//! Functionality to persist a Codex conversation *rollout* – a linear list of
|
||||
//! [`ResponseItem`] objects exchanged during a session – to disk so that
|
||||
//! sessions can be replayed or inspected later (mirrors the behaviour of the
|
||||
//! upstream TypeScript implementation).
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::git_info::GitInfo;
|
||||
use crate::git_info::collect_git_info;
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
/// Folder inside `~/.codex` that holds saved rollouts.
|
||||
const SESSIONS_SUBDIR: &str = "sessions";
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Default)]
|
||||
pub struct SessionMeta {
|
||||
pub id: Uuid,
|
||||
pub timestamp: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionMetaWithGit {
|
||||
#[serde(flatten)]
|
||||
meta: SessionMeta,
|
||||
struct SessionMeta {
|
||||
id: String,
|
||||
timestamp: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
git: Option<GitInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: Uuid,
|
||||
instructions: Option<String>,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
@@ -65,13 +41,7 @@ pub struct SavedSession {
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
Shutdown { ack: oneshot::Sender<()> },
|
||||
tx: Sender<String>,
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
@@ -89,6 +59,7 @@ impl RolloutRecorder {
|
||||
timestamp,
|
||||
} = create_log_file(config, uuid)?;
|
||||
|
||||
// Build the static session metadata JSON first.
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
@@ -96,33 +67,48 @@ impl RolloutRecorder {
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
// Clone the cwd for the spawned task to collect git info asynchronously
|
||||
let cwd = config.cwd.clone();
|
||||
let meta = SessionMeta {
|
||||
timestamp,
|
||||
id: session_id.to_string(),
|
||||
instructions,
|
||||
};
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller's thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
// perform *blocking* I/O on the caller’s thread.
|
||||
let (tx, mut rx) = mpsc::channel::<String>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(SessionMeta {
|
||||
timestamp,
|
||||
id: session_id,
|
||||
instructions,
|
||||
}),
|
||||
cwd,
|
||||
));
|
||||
tokio::task::spawn(async move {
|
||||
let mut file = tokio::fs::File::from_std(file);
|
||||
|
||||
Ok(Self { tx })
|
||||
while let Some(line) = rx.recv().await {
|
||||
// Write line + newline, then flush to disk.
|
||||
if let Err(e) = file.write_all(line.as_bytes()).await {
|
||||
tracing::warn!("rollout writer: failed to write line: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.write_all(b"\n").await {
|
||||
tracing::warn!("rollout writer: failed to write newline: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.flush().await {
|
||||
tracing::warn!("rollout writer: failed to flush: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let recorder = Self { tx };
|
||||
// Ensure SessionMeta is the first item in the file.
|
||||
recorder.record_item(&meta).await?;
|
||||
Ok(recorder)
|
||||
}
|
||||
|
||||
/// Append `items` to the rollout file.
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
match item {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
@@ -131,114 +117,27 @@ impl RolloutRecorder {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => filtered.push(item.clone()),
|
||||
ResponseItem::Other => {
|
||||
| ResponseItem::FunctionCallOutput { .. } => {}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
||||
// These should never be serialized.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.record_item(item).await?;
|
||||
}
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
|
||||
// Serialize the item to JSON first so that the writer thread only has
|
||||
// to perform the actual write.
|
||||
let json = serde_json::to_string(item)
|
||||
.map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
|
||||
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.send(json)
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
self.tx
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
}
|
||||
|
||||
pub async fn resume(
|
||||
path: &Path,
|
||||
cwd: std::path::PathBuf,
|
||||
) -> std::io::Result<(Self, SavedSession)> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
||||
let mut items = Vec::new();
|
||||
let mut state = SessionStateSnapshot::default();
|
||||
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
||||
state = s
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
Ok(item) => match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => items.push(item),
|
||||
ResponseItem::Other => {}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("failed to parse item: {v:?}, error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let saved = SavedSession {
|
||||
session: session.clone(),
|
||||
items: items.clone(),
|
||||
state: state.clone(),
|
||||
session_id: session.id,
|
||||
};
|
||||
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.read(true)
|
||||
.open(path)?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
None,
|
||||
cwd,
|
||||
));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
let (tx_done, rx_done) = oneshot::channel();
|
||||
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
|
||||
Ok(_) => rx_done
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
|
||||
Err(e) => {
|
||||
warn!("failed to send rollout shutdown command: {e}");
|
||||
Err(IoError::other(format!(
|
||||
"failed to send rollout shutdown command: {e}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,77 +185,3 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
async fn rollout_writer(
|
||||
file: tokio::fs::File,
|
||||
mut rx: mpsc::Receiver<RolloutCmd>,
|
||||
mut meta: Option<SessionMeta>,
|
||||
cwd: std::path::PathBuf,
|
||||
) -> std::io::Result<()> {
|
||||
let mut writer = JsonlWriter { file };
|
||||
|
||||
// If we have a meta, collect git info asynchronously and write meta first
|
||||
if let Some(session_meta) = meta.take() {
|
||||
let git_info = collect_git_info(&cwd).await;
|
||||
let session_meta_with_git = SessionMetaWithGit {
|
||||
meta: session_meta,
|
||||
git: git_info,
|
||||
};
|
||||
|
||||
// Write the SessionMeta as the first item in the file
|
||||
writer.write_line(&session_meta_with_git).await?;
|
||||
}
|
||||
|
||||
// Process rollout commands
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::Reasoning { .. } => {
|
||||
writer.write_line(&item).await?;
|
||||
}
|
||||
ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
}
|
||||
writer
|
||||
.write_line(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
RolloutCmd::Shutdown { ack } => {
|
||||
let _ = ack.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct JsonlWriter {
|
||||
file: tokio::fs::File,
|
||||
}
|
||||
|
||||
impl JsonlWriter {
|
||||
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
|
||||
let mut json = serde_json::to_string(item)?;
|
||||
json.push('\n');
|
||||
let _ = self.file.write_all(json.as_bytes()).await;
|
||||
self.file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
use shlex;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ZshShell {
|
||||
shell_path: String,
|
||||
zshrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Shell {
|
||||
Zsh(ZshShell),
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Shell {
|
||||
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
|
||||
match self {
|
||||
Shell::Zsh(zsh) => {
|
||||
if !std::path::Path::new(&zsh.zshrc_path).exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut result = vec![zsh.shell_path.clone(), "-c".to_string()];
|
||||
if let Ok(joined) = shlex::try_join(command.iter().map(|s| s.as_str())) {
|
||||
result.push(format!("source {} && ({joined})", zsh.zshrc_path));
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
Shell::Unknown => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
use tokio::process::Command;
|
||||
use whoami;
|
||||
|
||||
let user = whoami::username();
|
||||
let home = format!("/Users/{user}");
|
||||
let output = Command::new("dscl")
|
||||
.args([".", "-read", &home, "UserShell"])
|
||||
.output()
|
||||
.await
|
||||
.ok();
|
||||
match output {
|
||||
Some(o) => {
|
||||
if !o.status.success() {
|
||||
return Shell::Unknown;
|
||||
}
|
||||
let stdout = String::from_utf8_lossy(&o.stdout);
|
||||
for line in stdout.lines() {
|
||||
if let Some(shell_path) = line.strip_prefix("UserShell: ") {
|
||||
if shell_path.ends_with("/zsh") {
|
||||
return Shell::Zsh(ZshShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
zshrc_path: format!("{home}/.zshrc"),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Shell::Unknown
|
||||
}
|
||||
_ => Shell::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
pub async fn default_user_shell() -> Shell {
|
||||
Shell::Unknown
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(target_os = "macos")]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::process::Command;
|
||||
|
||||
#[tokio::test]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
async fn test_current_shell_detects_zsh() {
|
||||
let shell = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg("echo $SHELL")
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let home = std::env::var("HOME").unwrap();
|
||||
let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
|
||||
if shell_path.ends_with("/zsh") {
|
||||
assert_eq!(
|
||||
default_user_shell().await,
|
||||
Shell::Zsh(ZshShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
zshrc_path: format!("{home}/.zshrc",),
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_zshrc_not_exists() {
|
||||
let shell = Shell::Zsh(ZshShell {
|
||||
shell_path: "/bin/zsh".to_string(),
|
||||
zshrc_path: "/does/not/exist/.zshrc".to_string(),
|
||||
});
|
||||
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
|
||||
assert_eq!(actual_cmd, None);
|
||||
}
|
||||
|
||||
#[expect(clippy::unwrap_used)]
|
||||
#[tokio::test]
|
||||
async fn test_run_with_profile_escaping_and_execution() {
|
||||
let shell_path = "/bin/zsh";
|
||||
|
||||
let cases = vec![
|
||||
(
|
||||
vec!["myecho"],
|
||||
vec![shell_path, "-c", "source ZSHRC_PATH && (myecho)"],
|
||||
Some("It works!\n"),
|
||||
),
|
||||
(
|
||||
vec!["bash", "-lc", "echo 'single' \"double\""],
|
||||
vec![
|
||||
shell_path,
|
||||
"-c",
|
||||
"source ZSHRC_PATH && (bash -lc \"echo 'single' \\\"double\\\"\")",
|
||||
],
|
||||
Some("single double\n"),
|
||||
),
|
||||
];
|
||||
for (input, expected_cmd, expected_output) in cases {
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::exec::ExecParams;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
// create a temp directory with a zshrc file in it
|
||||
let temp_home = tempfile::tempdir().unwrap();
|
||||
let zshrc_path = temp_home.path().join(".zshrc");
|
||||
std::fs::write(
|
||||
&zshrc_path,
|
||||
r#"
|
||||
set -x
|
||||
function myecho {
|
||||
echo 'It works!'
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let shell = Shell::Zsh(ZshShell {
|
||||
shell_path: shell_path.to_string(),
|
||||
zshrc_path: zshrc_path.to_str().unwrap().to_string(),
|
||||
});
|
||||
|
||||
let actual_cmd = shell
|
||||
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||
let expected_cmd = expected_cmd
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
|
||||
.to_string()
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||
// Actually run the command and check output/exit code
|
||||
let output = process_exec_tool_call(
|
||||
ExecParams {
|
||||
command: actual_cmd.unwrap(),
|
||||
cwd: PathBuf::from(temp_home.path()),
|
||||
timeout_ms: None,
|
||||
env: HashMap::from([(
|
||||
"HOME".to_string(),
|
||||
temp_home.path().to_str().unwrap().to_string(),
|
||||
)]),
|
||||
},
|
||||
SandboxType::None,
|
||||
Arc::new(Notify::new()),
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
&None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
|
||||
if let Some(expected) = expected_output {
|
||||
assert_eq!(
|
||||
output.stdout, expected,
|
||||
"input: {input:?} output: {output:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,21 +64,3 @@ pub fn is_inside_git_repo(config: &Config) -> bool {
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// If `val` is a path to a readable file, return its trimmed contents.
|
||||
///
|
||||
/// - When `val` points to a file, this reads the file, trims leading/trailing
|
||||
/// whitespace and returns `Ok(Some(contents))` unless the trimmed contents are
|
||||
/// empty in which case it returns `Ok(None)`.
|
||||
/// - When `val` is not a file path, return `Ok(Some(val.to_string()))` so
|
||||
/// callers can treat the value as a literal string.
|
||||
pub fn maybe_read_file(val: &str) -> std::io::Result<Option<String>> {
|
||||
let p = std::path::Path::new(val);
|
||||
if p.is_file() {
|
||||
let s = std::fs::read_to_string(p)?;
|
||||
let s = s.trim().to_string();
|
||||
if s.is_empty() { Ok(None) } else { Ok(Some(s)) }
|
||||
} else {
|
||||
Ok(Some(val.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
@@ -122,7 +123,6 @@ async fn responses_api_stream_cli() {
|
||||
assert!(stdout.contains("fixture hello"));
|
||||
}
|
||||
|
||||
/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_creates_and_checks_session_file() {
|
||||
// Honor sandbox network restrictions for CI parity with the other tests.
|
||||
@@ -170,66 +170,45 @@ async fn integration_creates_and_checks_session_file() {
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
// Wait for sessions dir to appear.
|
||||
// 5. Sessions are written asynchronously; wait briefly for the directory to appear.
|
||||
let sessions_dir = home.path().join("sessions");
|
||||
let dir_deadline = Instant::now() + Duration::from_secs(5);
|
||||
while !sessions_dir.exists() && Instant::now() < dir_deadline {
|
||||
let start = Instant::now();
|
||||
while !sessions_dir.exists() && start.elapsed() < Duration::from_secs(3) {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
assert!(sessions_dir.exists(), "sessions directory never appeared");
|
||||
|
||||
// Find the session file that contains `marker`.
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
let mut matching_path: Option<std::path::PathBuf> = None;
|
||||
while Instant::now() < deadline && matching_path.is_none() {
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
// 6. Scan all session files and find the one that contains our marker.
|
||||
let mut matching_files = vec![];
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = entry.unwrap();
|
||||
if entry.file_type().is_file() && entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
let path = entry.path();
|
||||
let Ok(content) = std::fs::read_to_string(path) else {
|
||||
continue;
|
||||
};
|
||||
let content = std::fs::read_to_string(path).unwrap();
|
||||
let mut lines = content.lines();
|
||||
if lines.next().is_none() {
|
||||
continue;
|
||||
}
|
||||
// Skip SessionMeta (first line)
|
||||
let _ = lines.next();
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let item: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
|
||||
if let Some(c) = item.get("content") {
|
||||
if c.to_string().contains(&marker) {
|
||||
matching_path = Some(path.to_path_buf());
|
||||
let item: Value = serde_json::from_str(line).unwrap();
|
||||
if let Some("message") = item.get("type").and_then(|t| t.as_str()) {
|
||||
if let Some(content) = item.get("content") {
|
||||
if content.to_string().contains(&marker) {
|
||||
matching_files.push(path.to_owned());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if matching_path.is_none() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
assert_eq!(
|
||||
matching_files.len(),
|
||||
1,
|
||||
"Expected exactly one session file containing the marker, found {}",
|
||||
matching_files.len()
|
||||
);
|
||||
let path = &matching_files[0];
|
||||
|
||||
let path = match matching_path {
|
||||
Some(p) => p,
|
||||
None => panic!("No session file containing the marker was found"),
|
||||
};
|
||||
|
||||
// Basic sanity checks on location and metadata.
|
||||
// 7. Verify directory structure: sessions/YYYY/MM/DD/filename.jsonl
|
||||
let rel = match path.strip_prefix(&sessions_dir) {
|
||||
Ok(r) => r,
|
||||
Err(_) => panic!("session file should live under sessions/"),
|
||||
@@ -258,6 +237,7 @@ async fn integration_creates_and_checks_session_file() {
|
||||
day.len() == 2 && day.chars().all(|c| c.is_ascii_digit()),
|
||||
"Day dir not zero-padded 2-digit numeric: {day}"
|
||||
);
|
||||
// Range checks (best-effort; won't fail on leading zeros)
|
||||
if let Ok(m) = month.parse::<u8>() {
|
||||
assert!((1..=12).contains(&m), "Month out of range: {m}");
|
||||
}
|
||||
@@ -265,32 +245,23 @@ async fn integration_creates_and_checks_session_file() {
|
||||
assert!((1..=31).contains(&d), "Day out of range: {d}");
|
||||
}
|
||||
|
||||
let content =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| panic!("Failed to read session file"));
|
||||
// 8. Parse SessionMeta line and basic sanity checks.
|
||||
let content = std::fs::read_to_string(path).unwrap();
|
||||
let mut lines = content.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or("missing session meta line")
|
||||
.unwrap_or_else(|_| panic!("missing session meta line"));
|
||||
let meta: serde_json::Value = serde_json::from_str(meta_line)
|
||||
.unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
|
||||
let meta: Value = serde_json::from_str(lines.next().unwrap()).unwrap();
|
||||
assert!(meta.get("id").is_some(), "SessionMeta missing id");
|
||||
assert!(
|
||||
meta.get("timestamp").is_some(),
|
||||
"SessionMeta missing timestamp"
|
||||
);
|
||||
|
||||
// 9. Confirm at least one message contains the marker.
|
||||
let mut found_message = false;
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
|
||||
continue;
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
|
||||
if let Some(c) = item.get("content") {
|
||||
if c.to_string().contains(&marker) {
|
||||
let item: Value = serde_json::from_str(line).unwrap();
|
||||
if item.get("type").map(|t| t == "message").unwrap_or(false) {
|
||||
if let Some(content) = item.get("content") {
|
||||
if content.to_string().contains(&marker) {
|
||||
found_message = true;
|
||||
break;
|
||||
}
|
||||
@@ -301,184 +272,4 @@ async fn integration_creates_and_checks_session_file() {
|
||||
found_message,
|
||||
"No message found in session file containing the marker"
|
||||
);
|
||||
|
||||
// Second run: resume and append.
|
||||
let orig_len = content.lines().count();
|
||||
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
|
||||
let prompt2 = format!("echo {marker2}");
|
||||
// Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped
|
||||
// or the parse will fail and the raw literal (including quotes) may be preserved all the way down
|
||||
// to Config, which in turn breaks resume because the path is invalid. Normalize to forward slashes
|
||||
// to sidestep the issue.
|
||||
let resume_path_str = path.to_string_lossy().replace('\\', "/");
|
||||
let resume_override = format!("experimental_resume=\"{resume_path_str}\"");
|
||||
let mut cmd2 = AssertCommand::new("cargo");
|
||||
cmd2.arg("run")
|
||||
.arg("-p")
|
||||
.arg("codex-cli")
|
||||
.arg("--quiet")
|
||||
.arg("--")
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-c")
|
||||
.arg(&resume_override)
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg(&prompt2);
|
||||
cmd2.env("CODEX_HOME", home.path())
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("CODEX_RS_SSE_FIXTURE", &fixture)
|
||||
.env("OPENAI_BASE_URL", "http://unused.local");
|
||||
|
||||
let output2 = cmd2.output().unwrap();
|
||||
assert!(output2.status.success(), "resume codex-cli run failed");
|
||||
|
||||
// The rollout writer runs on a background async task; give it a moment to flush.
|
||||
let mut new_len = orig_len;
|
||||
let deadline = Instant::now() + Duration::from_secs(5);
|
||||
let mut content2 = String::new();
|
||||
while Instant::now() < deadline {
|
||||
if let Ok(c) = std::fs::read_to_string(&path) {
|
||||
let count = c.lines().count();
|
||||
if count > orig_len {
|
||||
content2 = c;
|
||||
new_len = count;
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
if content2.is_empty() {
|
||||
// last attempt
|
||||
content2 = std::fs::read_to_string(&path).unwrap();
|
||||
new_len = content2.lines().count();
|
||||
}
|
||||
assert!(new_len > orig_len, "rollout file did not grow after resume");
|
||||
assert!(content2.contains(&marker), "rollout lost original marker");
|
||||
assert!(
|
||||
content2.contains(&marker2),
|
||||
"rollout missing resumed marker"
|
||||
);
|
||||
}
|
||||
|
||||
/// Integration test to verify git info is collected and recorded in session files.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_git_info_unit_test() {
|
||||
// This test verifies git info collection works independently
|
||||
// without depending on the full CLI integration
|
||||
|
||||
// 1. Create temp directory for git repo
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let git_repo = temp_dir.path().to_path_buf();
|
||||
|
||||
// 2. Initialize a git repository with some content
|
||||
let init_output = std::process::Command::new("git")
|
||||
.args(["init"])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
assert!(init_output.status.success(), "git init failed");
|
||||
|
||||
// Configure git user (required for commits)
|
||||
std::process::Command::new("git")
|
||||
.args(["config", "user.name", "Integration Test"])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
std::process::Command::new("git")
|
||||
.args(["config", "user.email", "test@example.com"])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// Create a test file and commit it
|
||||
let test_file = git_repo.join("test.txt");
|
||||
std::fs::write(&test_file, "integration test content").unwrap();
|
||||
|
||||
std::process::Command::new("git")
|
||||
.args(["add", "."])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let commit_output = std::process::Command::new("git")
|
||||
.args(["commit", "-m", "Integration test commit"])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
assert!(commit_output.status.success(), "git commit failed");
|
||||
|
||||
// Create a branch to test branch detection
|
||||
std::process::Command::new("git")
|
||||
.args(["checkout", "-b", "integration-test-branch"])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// Add a remote to test repository URL detection
|
||||
std::process::Command::new("git")
|
||||
.args([
|
||||
"remote",
|
||||
"add",
|
||||
"origin",
|
||||
"https://github.com/example/integration-test.git",
|
||||
])
|
||||
.current_dir(&git_repo)
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// 3. Test git info collection directly
|
||||
let git_info = codex_core::git_info::collect_git_info(&git_repo).await;
|
||||
|
||||
// 4. Verify git info is present and contains expected data
|
||||
assert!(git_info.is_some(), "Git info should be collected");
|
||||
|
||||
let git_info = git_info.unwrap();
|
||||
|
||||
// Check that we have a commit hash
|
||||
assert!(
|
||||
git_info.commit_hash.is_some(),
|
||||
"Git info should contain commit_hash"
|
||||
);
|
||||
let commit_hash = git_info.commit_hash.as_ref().unwrap();
|
||||
assert_eq!(commit_hash.len(), 40, "Commit hash should be 40 characters");
|
||||
assert!(
|
||||
commit_hash.chars().all(|c| c.is_ascii_hexdigit()),
|
||||
"Commit hash should be hexadecimal"
|
||||
);
|
||||
|
||||
// Check that we have the correct branch
|
||||
assert!(git_info.branch.is_some(), "Git info should contain branch");
|
||||
let branch = git_info.branch.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
branch, "integration-test-branch",
|
||||
"Branch should match what we created"
|
||||
);
|
||||
|
||||
// Check that we have the repository URL
|
||||
assert!(
|
||||
git_info.repository_url.is_some(),
|
||||
"Git info should contain repository_url"
|
||||
);
|
||||
let repo_url = git_info.repository_url.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
repo_url, "https://github.com/example/integration-test.git",
|
||||
"Repository URL should match what we configured"
|
||||
);
|
||||
|
||||
println!("✅ Git info collection test passed!");
|
||||
println!(" Commit: {commit_hash}");
|
||||
println!(" Branch: {branch}");
|
||||
println!(" Repo: {repo_url}");
|
||||
|
||||
// 5. Test serialization to ensure it works in SessionMeta
|
||||
let serialized = serde_json::to_string(&git_info).unwrap();
|
||||
let deserialized: codex_core::git_info::GitInfo = serde_json::from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(git_info.commit_hash, deserialized.commit_hash);
|
||||
assert_eq!(git_info.branch, deserialized.branch);
|
||||
assert_eq!(git_info.repository_url, deserialized.repository_url);
|
||||
|
||||
println!("✅ Git info serialization test passed!");
|
||||
}
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_session_id_and_model_headers_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[("originator".to_string(), "codex_cli_rs".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_))).await
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let current_session_id = Some(session_id.to_string());
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.headers.get("session_id").unwrap();
|
||||
let originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert!(current_session_id.is_some());
|
||||
assert_eq!(
|
||||
request_body.to_str().unwrap(),
|
||||
current_session_id.as_ref().unwrap()
|
||||
);
|
||||
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_base_instructions_override_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
|
||||
config.base_instructions = Some("test instructions".to_string());
|
||||
config.model_provider = model_provider;
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
assert!(
|
||||
request_body["instructions"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("test instructions")
|
||||
);
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
[package]
|
||||
name = "core_test_support"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
codex-core = { path = "../.." }
|
||||
serde_json = "1"
|
||||
tempfile = "3"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
@@ -20,15 +20,15 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::error::CodexErr;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
mod test_support;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::timeout;
|
||||
|
||||
@@ -49,8 +49,7 @@ async fn spawn_codex() -> Result<Codex, CodexErr> {
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider.request_max_retries = Some(2);
|
||||
config.model_provider.stream_max_retries = Some(2);
|
||||
let CodexSpawnOk { codex: agent, .. } =
|
||||
Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
|
||||
let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
|
||||
|
||||
Ok(agent)
|
||||
}
|
||||
|
||||
165
codex-rs/core/tests/previous_response_id.rs
Normal file
165
codex-rs/core/tests/previous_response_id.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
mod test_support;
|
||||
use serde_json::Value;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use test_support::load_sse_fixture_with_id;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Request;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Matcher asserting that JSON body has NO `previous_response_id` field.
|
||||
struct NoPrevId;
|
||||
|
||||
impl Match for NoPrevId {
|
||||
fn matches(&self, req: &Request) -> bool {
|
||||
serde_json::from_slice::<Value>(&req.body)
|
||||
.map(|v| v.get("previous_response_id").is_none())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Matcher asserting that JSON body HAS a `previous_response_id` field.
|
||||
struct HasPrevId;
|
||||
|
||||
impl Match for HasPrevId {
|
||||
fn matches(&self, req: &Request) -> bool {
|
||||
serde_json::from_slice::<Value>(&req.body)
|
||||
.map(|v| v.get("previous_response_id").is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn keeps_previous_response_id_between_tasks() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(NoPrevId)
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Second request – MUST include `previous_response_id`.
|
||||
let second = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp2"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(HasPrevId)
|
||||
.respond_with(second)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
||||
// environment variables.
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// disable retries so we don't get duplicate calls in this test
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
// Task 1 – triggers first request (no previous_response_id)
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for TaskComplete
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Task 2 – should include `previous_response_id` (triggers second request)
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "again".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for TaskComplete or error
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
match ev.msg {
|
||||
EventMsg::TaskComplete(_) => break,
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
panic!("unexpected error: {message}")
|
||||
}
|
||||
_ => {
|
||||
// Ignore other events.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,16 +4,16 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::CodexSpawnOk;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
mod test_support;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use test_support::load_sse_fixture;
|
||||
use test_support::load_sse_fixture_with_id;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
@@ -95,7 +95,7 @@ async fn retries_on_early_close() {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c).await.unwrap();
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
// Helpers shared by the integration tests. These are located inside the
|
||||
// `tests/` tree on purpose so they never become part of the public API surface
|
||||
// of the `codex-core` crate.
|
||||
|
||||
use tempfile::TempDir;
|
||||
|
||||
use codex_core::config::Config;
|
||||
@@ -26,6 +30,7 @@ pub fn load_default_config_for_test(codex_home: &TempDir) -> Config {
|
||||
/// with only a `type` field results in an event with no `data:` section. This
|
||||
/// makes it trivial to extend the fixtures as OpenAI adds new event kinds or
|
||||
/// fields.
|
||||
#[allow(dead_code)]
|
||||
pub fn load_sse_fixture(path: impl AsRef<std::path::Path>) -> String {
|
||||
let events: Vec<serde_json::Value> =
|
||||
serde_json::from_reader(std::fs::File::open(path).expect("read fixture"))
|
||||
@@ -50,6 +55,7 @@ pub fn load_sse_fixture(path: impl AsRef<std::path::Path>) -> String {
|
||||
/// fixture template with the supplied identifier before parsing. This lets a
|
||||
/// single JSON template be reused by multiple tests that each need a unique
|
||||
/// `response_id`.
|
||||
#[allow(dead_code)]
|
||||
pub fn load_sse_fixture_with_id(path: impl AsRef<std::path::Path>, id: &str) -> String {
|
||||
let raw = std::fs::read_to_string(path).expect("read fixture template");
|
||||
let replaced = raw.replace("__ID__", id);
|
||||
@@ -70,23 +76,3 @@ pub fn load_sse_fixture_with_id(path: impl AsRef<std::path::Path>, id: &str) ->
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn wait_for_event<F>(
|
||||
codex: &codex_core::Codex,
|
||||
mut predicate: F,
|
||||
) -> codex_core::protocol::EventMsg
|
||||
where
|
||||
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
|
||||
{
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("stream ended unexpectedly");
|
||||
if predicate(&ev.msg) {
|
||||
return ev.msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,13 +18,13 @@ workspace = true
|
||||
anyhow = "1"
|
||||
chrono = "0.4.40"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
codex-core = { path = "../core" }
|
||||
codex-common = { path = "../common", features = [
|
||||
"cli",
|
||||
"elapsed",
|
||||
"sandbox_summary",
|
||||
] }
|
||||
codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
owo-colors = "4.2.0"
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
@@ -37,8 +37,3 @@ tokio = { version = "1", features = [
|
||||
] }
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
predicates = "3"
|
||||
tempfile = "3.13.0"
|
||||
|
||||
@@ -63,40 +63,6 @@ pub struct Cli {
|
||||
/// if `-` is used), instructions are read from stdin.
|
||||
#[arg(value_name = "PROMPT")]
|
||||
pub prompt: Option<String>,
|
||||
|
||||
/// Override the built-in system prompt (base instructions).
|
||||
///
|
||||
/// If the value looks like a path to an existing file, the contents of the
|
||||
/// file are used. Otherwise, the value itself is used verbatim as the
|
||||
/// instructions string.
|
||||
#[arg(long = "experimental-instructions")]
|
||||
pub experimental_instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Cli;
|
||||
use clap::CommandFactory;
|
||||
|
||||
#[test]
|
||||
fn help_includes_file_behavior_for_experimental_instructions() {
|
||||
let mut cmd = Cli::command();
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
assert!(cmd.write_long_help(&mut buf).is_ok(), "help should render");
|
||||
let help = match String::from_utf8(buf) {
|
||||
Ok(s) => s,
|
||||
Err(e) => panic!("invalid utf8: {e}"),
|
||||
};
|
||||
assert!(help.contains("Override the built-in system prompt (base instructions)."));
|
||||
assert!(help.contains(
|
||||
"If the value looks like a path to an existing file, the contents of the file are used."
|
||||
));
|
||||
assert!(
|
||||
help.contains(
|
||||
"Otherwise, the value itself is used verbatim as the instructions string."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
|
||||
|
||||
@@ -1,53 +1,25 @@
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_common::summarize_sandbox_policy;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::model_supports_reasoning_summaries;
|
||||
use codex_core::protocol::Event;
|
||||
|
||||
pub(crate) enum CodexStatus {
|
||||
Running,
|
||||
InitiateShutdown,
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
pub(crate) trait EventProcessor {
|
||||
/// Print summary of effective configuration and user prompt.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str);
|
||||
|
||||
/// Handle a single event emitted by the agent.
|
||||
fn process_event(&mut self, event: Event) -> CodexStatus;
|
||||
fn process_event(&mut self, event: Event);
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum PromptOrigin {
|
||||
File(PathBuf),
|
||||
Literal,
|
||||
}
|
||||
|
||||
pub(crate) fn create_config_summary_entries(
|
||||
config: &Config,
|
||||
prompt_origin: Option<&PromptOrigin>,
|
||||
) -> Vec<(&'static str, String)> {
|
||||
pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> {
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", config.approval_policy.to_string()),
|
||||
("approval", format!("{:?}", config.approval_policy)),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if let Some(origin) = prompt_origin {
|
||||
let prompt_val = match origin {
|
||||
PromptOrigin::Literal => "experimental".to_string(),
|
||||
PromptOrigin::File(path) => path
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| path.display().to_string()),
|
||||
};
|
||||
entries.push(("prompt_origin", prompt_val));
|
||||
}
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
@@ -63,89 +35,3 @@ pub(crate) fn create_config_summary_entries(
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
pub(crate) fn handle_last_message(
|
||||
last_agent_message: Option<&str>,
|
||||
last_message_path: Option<&Path>,
|
||||
) {
|
||||
match (last_message_path, last_agent_message) {
|
||||
(Some(path), Some(msg)) => write_last_message_file(msg, Some(path)),
|
||||
(Some(path), None) => {
|
||||
write_last_message_file("", Some(path));
|
||||
eprintln!(
|
||||
"Warning: no last agent message; wrote empty content to {}",
|
||||
path.display()
|
||||
);
|
||||
}
|
||||
(None, _) => eprintln!("Warning: no file to write last message to."),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_last_message_file(contents: &str, last_message_path: Option<&Path>) {
|
||||
if let Some(path) = last_message_path {
|
||||
if let Err(e) = std::fs::write(path, contents) {
|
||||
eprintln!("Failed to write last message file {path:?}: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use std::collections::HashMap;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn minimal_config() -> Config {
|
||||
let cwd = match TempDir::new() {
|
||||
Ok(t) => t,
|
||||
Err(e) => panic!("tempdir error: {e}"),
|
||||
};
|
||||
let codex_home = match TempDir::new() {
|
||||
Ok(t) => t,
|
||||
Err(e) => panic!("tempdir error: {e}"),
|
||||
};
|
||||
let cfg = ConfigToml {
|
||||
..Default::default()
|
||||
};
|
||||
let overrides = ConfigOverrides {
|
||||
cwd: Some(cwd.path().to_path_buf()),
|
||||
..Default::default()
|
||||
};
|
||||
match Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
overrides,
|
||||
codex_home.path().to_path_buf(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => panic!("config error: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn entries_include_prompt_origin_experimental_for_literal_origin() {
|
||||
let mut cfg = minimal_config();
|
||||
cfg.base_instructions = Some("hello".to_string());
|
||||
let entries = create_config_summary_entries(&cfg, Some(&PromptOrigin::Literal));
|
||||
let map: HashMap<_, _> = entries.into_iter().collect();
|
||||
assert_eq!(
|
||||
map.get("prompt_origin").cloned(),
|
||||
Some("experimental".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn entries_include_prompt_origin_filename_for_file_origin() {
|
||||
let mut cfg = minimal_config();
|
||||
cfg.base_instructions = Some("hello".to_string());
|
||||
let path = PathBuf::from("/tmp/custom_instructions.txt");
|
||||
let entries = create_config_summary_entries(&cfg, Some(&PromptOrigin::File(path.clone())));
|
||||
let map: HashMap<_, _> = entries.into_iter().collect();
|
||||
assert_eq!(
|
||||
map.get("prompt_origin").cloned(),
|
||||
Some("custom_instructions.txt".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,21 +15,16 @@ use codex_core::protocol::McpToolCallEndEvent;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::PromptOrigin;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
use crate::event_processor::handle_last_message;
|
||||
|
||||
/// This should be configurable. When used in CI, users may not want to impose
|
||||
/// a limit so they can see the full transcript.
|
||||
@@ -59,17 +54,10 @@ pub(crate) struct EventProcessorWithHumanOutput {
|
||||
show_agent_reasoning: bool,
|
||||
answer_started: bool,
|
||||
reasoning_started: bool,
|
||||
last_message_path: Option<PathBuf>,
|
||||
prompt_origin: Option<PromptOrigin>,
|
||||
}
|
||||
|
||||
impl EventProcessorWithHumanOutput {
|
||||
pub(crate) fn create_with_ansi(
|
||||
with_ansi: bool,
|
||||
config: &Config,
|
||||
last_message_path: Option<PathBuf>,
|
||||
prompt_origin: Option<PromptOrigin>,
|
||||
) -> Self {
|
||||
pub(crate) fn create_with_ansi(with_ansi: bool, config: &Config) -> Self {
|
||||
let call_id_to_command = HashMap::new();
|
||||
let call_id_to_patch = HashMap::new();
|
||||
let call_id_to_tool_call = HashMap::new();
|
||||
@@ -89,8 +77,6 @@ impl EventProcessorWithHumanOutput {
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
last_message_path,
|
||||
prompt_origin,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
@@ -107,8 +93,6 @@ impl EventProcessorWithHumanOutput {
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
last_message_path,
|
||||
prompt_origin,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -155,7 +139,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
VERSION
|
||||
);
|
||||
|
||||
let entries = create_config_summary_entries(config, self.prompt_origin.as_ref());
|
||||
let entries = create_config_summary_entries(config);
|
||||
|
||||
for (key, value) in entries {
|
||||
println!("{} {}", format!("{key}:").style(self.bold), value);
|
||||
@@ -174,7 +158,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
);
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) -> CodexStatus {
|
||||
fn process_event(&mut self, event: Event) {
|
||||
let Event { id: _, msg } = event;
|
||||
match msg {
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
@@ -184,16 +168,9 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
|
||||
ts_println!(self, "{}", message.style(self.dimmed));
|
||||
}
|
||||
EventMsg::TaskStarted => {
|
||||
EventMsg::TaskStarted | EventMsg::TaskComplete(_) => {
|
||||
// Ignore.
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
handle_last_message(
|
||||
last_agent_message.as_deref(),
|
||||
self.last_message_path.as_deref(),
|
||||
);
|
||||
return CodexStatus::InitiateShutdown;
|
||||
}
|
||||
EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => {
|
||||
ts_println!(self, "tokens used: {total_tokens}");
|
||||
}
|
||||
@@ -208,7 +185,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
|
||||
if !self.show_agent_reasoning {
|
||||
return CodexStatus::Running;
|
||||
return;
|
||||
}
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
@@ -521,9 +498,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
EventMsg::GetHistoryEntryResponse(_) => {
|
||||
// Currently ignored in exec output.
|
||||
}
|
||||
EventMsg::ShutdownComplete => return CodexStatus::Shutdown,
|
||||
}
|
||||
CodexStatus::Running
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,35 +1,24 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::PromptOrigin;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
use crate::event_processor::handle_last_message;
|
||||
|
||||
pub(crate) struct EventProcessorWithJsonOutput {
|
||||
last_message_path: Option<PathBuf>,
|
||||
prompt_origin: Option<PromptOrigin>,
|
||||
}
|
||||
pub(crate) struct EventProcessorWithJsonOutput;
|
||||
|
||||
impl EventProcessorWithJsonOutput {
|
||||
pub fn new(last_message_path: Option<PathBuf>, prompt_origin: Option<PromptOrigin>) -> Self {
|
||||
Self {
|
||||
last_message_path,
|
||||
prompt_origin,
|
||||
}
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl EventProcessor for EventProcessorWithJsonOutput {
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
let entries = create_config_summary_entries(config, self.prompt_origin.as_ref())
|
||||
let entries = create_config_summary_entries(config)
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.to_string(), value))
|
||||
.collect::<HashMap<String, String>>();
|
||||
@@ -44,25 +33,15 @@ impl EventProcessor for EventProcessorWithJsonOutput {
|
||||
println!("{prompt_json}");
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) -> CodexStatus {
|
||||
fn process_event(&mut self, event: Event) {
|
||||
match event.msg {
|
||||
EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_) => {
|
||||
// Suppress streaming events in JSON mode.
|
||||
CodexStatus::Running
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
handle_last_message(
|
||||
last_agent_message.as_deref(),
|
||||
self.last_message_path.as_deref(),
|
||||
);
|
||||
CodexStatus::InitiateShutdown
|
||||
}
|
||||
EventMsg::ShutdownComplete => CodexStatus::Shutdown,
|
||||
_ => {
|
||||
if let Ok(line) = serde_json::to_string(&event) {
|
||||
println!("{line}");
|
||||
}
|
||||
CodexStatus::Running
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,13 +5,12 @@ mod event_processor_with_json_output;
|
||||
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::event_processor::PromptOrigin;
|
||||
pub use cli::Cli;
|
||||
use codex_core::codex_wrapper::CodexConversation;
|
||||
use codex_core::codex_wrapper::{self};
|
||||
use codex_core::codex_wrapper;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config_types::SandboxMode;
|
||||
@@ -22,7 +21,6 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use codex_core::util::maybe_read_file;
|
||||
use event_processor_with_human_output::EventProcessorWithHumanOutput;
|
||||
use event_processor_with_json_output::EventProcessorWithJsonOutput;
|
||||
use tracing::debug;
|
||||
@@ -30,7 +28,6 @@ use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
@@ -47,38 +44,9 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
json: json_mode,
|
||||
sandbox_mode: sandbox_mode_cli_arg,
|
||||
prompt,
|
||||
experimental_instructions,
|
||||
config_overrides,
|
||||
} = cli;
|
||||
|
||||
// Determine how to describe experimental instructions in the summary and
|
||||
// prepare the effective base instructions. If the flag points at a file,
|
||||
// read its contents; otherwise use the value verbatim.
|
||||
let mut prompt_origin = match experimental_instructions.as_deref() {
|
||||
Some(val) => {
|
||||
let p = std::path::Path::new(val);
|
||||
if p.is_file() {
|
||||
Some(PromptOrigin::File(p.to_path_buf()))
|
||||
} else {
|
||||
Some(PromptOrigin::Literal)
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let experimental_instructions = match experimental_instructions {
|
||||
Some(val) => match maybe_read_file(&val) {
|
||||
Ok(Some(contents)) => Some(contents),
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to read --experimental-instructions file: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
let has_experimental = experimental_instructions.is_some();
|
||||
|
||||
// Determine the prompt based on CLI arg and/or stdin.
|
||||
let prompt = match prompt {
|
||||
Some(p) if p != "-" => p,
|
||||
@@ -142,7 +110,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||
model_provider: None,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions: experimental_instructions,
|
||||
};
|
||||
// Parse `-c` overrides.
|
||||
let cli_kv_overrides = match config_overrides.parse_overrides() {
|
||||
@@ -154,21 +121,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
|
||||
if !has_experimental {
|
||||
prompt_origin = None;
|
||||
}
|
||||
|
||||
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
|
||||
Box::new(EventProcessorWithJsonOutput::new(
|
||||
last_message_file.clone(),
|
||||
prompt_origin.clone(),
|
||||
))
|
||||
Box::new(EventProcessorWithJsonOutput::new())
|
||||
} else {
|
||||
Box::new(EventProcessorWithHumanOutput::create_with_ansi(
|
||||
stdout_with_ansi,
|
||||
&config,
|
||||
last_message_file.clone(),
|
||||
prompt_origin,
|
||||
))
|
||||
};
|
||||
|
||||
@@ -195,14 +153,9 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
.with_writer(std::io::stderr)
|
||||
.try_init();
|
||||
|
||||
let CodexConversation {
|
||||
codex: codex_wrapper,
|
||||
session_configured,
|
||||
ctrl_c,
|
||||
..
|
||||
} = codex_wrapper::init_codex(config).await?;
|
||||
let (codex_wrapper, event, ctrl_c) = codex_wrapper::init_codex(config).await?;
|
||||
let codex = Arc::new(codex_wrapper);
|
||||
info!("Codex initialized with event: {session_configured:?}");
|
||||
info!("Codex initialized with event: {event:?}");
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Event>();
|
||||
{
|
||||
@@ -270,67 +223,40 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
|
||||
// Run the loop until the task is complete.
|
||||
while let Some(event) = rx.recv().await {
|
||||
let shutdown: CodexStatus = event_processor.process_event(event);
|
||||
match shutdown {
|
||||
CodexStatus::Running => continue,
|
||||
CodexStatus::InitiateShutdown => {
|
||||
codex.submit(Op::Shutdown).await?;
|
||||
}
|
||||
CodexStatus::Shutdown => {
|
||||
break;
|
||||
let (is_last_event, last_assistant_message) = match &event.msg {
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
(true, last_agent_message.clone())
|
||||
}
|
||||
_ => (false, None),
|
||||
};
|
||||
event_processor.process_event(event);
|
||||
if is_last_event {
|
||||
handle_last_message(last_assistant_message, last_message_file.as_deref())?;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use codex_core::util::maybe_read_file;
|
||||
use std::fs;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
fn maybe_read_file_returns_literal_for_non_path() {
|
||||
let res = match maybe_read_file("You are a helpful assistant.") {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!("error: {e}"),
|
||||
};
|
||||
assert_eq!(res, Some("You are a helpful assistant.".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_read_file_reads_and_trims_file_contents() {
|
||||
let tf = match NamedTempFile::new() {
|
||||
Ok(t) => t,
|
||||
Err(e) => panic!("tempfile: {e}"),
|
||||
};
|
||||
if let Err(e) = fs::write(tf.path(), " Hello world\n") {
|
||||
panic!("write temp file: {e}");
|
||||
fn handle_last_message(
|
||||
last_agent_message: Option<String>,
|
||||
last_message_file: Option<&Path>,
|
||||
) -> std::io::Result<()> {
|
||||
match (last_agent_message, last_message_file) {
|
||||
(Some(last_agent_message), Some(last_message_file)) => {
|
||||
// Last message and a file to write to.
|
||||
std::fs::write(last_message_file, last_agent_message)?;
|
||||
}
|
||||
let path_s = tf.path().to_string_lossy().to_string();
|
||||
let res = match maybe_read_file(&path_s) {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!("should read file successfully: {e}"),
|
||||
};
|
||||
assert_eq!(res, Some("Hello world".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_read_file_empty_file_returns_none() {
|
||||
let tf = match NamedTempFile::new() {
|
||||
Ok(t) => t,
|
||||
Err(e) => panic!("tempfile: {e}"),
|
||||
};
|
||||
if let Err(e) = fs::write(tf.path(), " \n\t ") {
|
||||
panic!("write temp file: {e}");
|
||||
(None, Some(last_message_file)) => {
|
||||
eprintln!(
|
||||
"Warning: No last message to write to file: {}",
|
||||
last_message_file.to_string_lossy()
|
||||
);
|
||||
}
|
||||
(_, None) => {
|
||||
// No last message and no file to write to.
|
||||
}
|
||||
let path_s = tf.path().to_string_lossy().to_string();
|
||||
let res = match maybe_read_file(&path_s) {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!("should read file successfully: {e}"),
|
||||
};
|
||||
assert_eq!(res, None);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
//! This allows us to ship a completely separate set of functionality as part
|
||||
//! of the `codex-exec` binary.
|
||||
use clap::Parser;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_exec::Cli;
|
||||
use codex_exec::run_main;
|
||||
@@ -25,7 +24,7 @@ struct TopCli {
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move {
|
||||
let top_cli = TopCli::parse();
|
||||
// Merge root-level overrides into inner CLI struct so downstream logic remains unchanged.
|
||||
let mut inner = top_cli.inner;
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// While we may add an `apply-patch` subcommand to the `codex` CLI multitool
|
||||
/// at some point, we must ensure that the smaller `codex-exec` CLI can still
|
||||
/// emulate the `apply_patch` CLI.
|
||||
#[test]
|
||||
fn test_standalone_exec_cli_can_use_apply_patch() -> anyhow::Result<()> {
|
||||
let tmp = tempdir()?;
|
||||
let relative_path = "source.txt";
|
||||
let absolute_path = tmp.path().join(relative_path);
|
||||
fs::write(&absolute_path, "original content\n")?;
|
||||
|
||||
Command::cargo_bin("codex-exec")
|
||||
.context("should find binary for codex-exec")?
|
||||
.arg("--codex-run-as-apply-patch")
|
||||
.arg(
|
||||
r#"*** Begin Patch
|
||||
*** Update File: source.txt
|
||||
@@
|
||||
-original content
|
||||
+modified by apply_patch
|
||||
*** End Patch"#,
|
||||
)
|
||||
.current_dir(tmp.path())
|
||||
.assert()
|
||||
.success()
|
||||
.stdout("Success. Updated the following files:\nM source.txt\n")
|
||||
.stderr(predicates::str::is_empty());
|
||||
assert_eq!(
|
||||
fs::read_to_string(absolute_path)?,
|
||||
"modified by apply_patch\n"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
@@ -14,16 +14,13 @@ path = "src/lib.rs"
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-common = { path = "../common", features = ["cli"] }
|
||||
codex-core = { path = "../core" }
|
||||
libc = "0.2.172"
|
||||
landlock = "0.4.1"
|
||||
seccompiler = "0.5.0"
|
||||
tokio = { version = "1", features = ["rt-multi-thread"] }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dev-dependencies]
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
@@ -32,3 +29,8 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
libc = "0.2.172"
|
||||
landlock = "0.4.1"
|
||||
seccompiler = "0.5.0"
|
||||
|
||||
@@ -4,8 +4,57 @@ mod landlock;
|
||||
mod linux_run_main;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn run_main() -> ! {
|
||||
linux_run_main::run_main();
|
||||
pub use linux_run_main::run_main;
|
||||
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Helper that consolidates the common boilerplate found in several Codex
|
||||
/// binaries (`codex`, `codex-exec`, `codex-tui`) around dispatching to the
|
||||
/// `codex-linux-sandbox` sub-command.
|
||||
///
|
||||
/// When the current executable is invoked through the hard-link or alias
|
||||
/// named `codex-linux-sandbox` we *directly* execute [`run_main`](crate::run_main)
|
||||
/// (which never returns). Otherwise we:
|
||||
/// 1. Construct a Tokio multi-thread runtime.
|
||||
/// 2. Derive the path to the current executable (so children can re-invoke
|
||||
/// the sandbox) when running on Linux.
|
||||
/// 3. Execute the provided async `main_fn` inside that runtime, forwarding
|
||||
/// any error.
|
||||
///
|
||||
/// This function eliminates duplicated code across the various `main.rs`
|
||||
/// entry-points.
|
||||
pub fn run_with_sandbox<F, Fut>(main_fn: F) -> anyhow::Result<()>
|
||||
where
|
||||
F: FnOnce(Option<PathBuf>) -> Fut,
|
||||
Fut: Future<Output = anyhow::Result<()>>,
|
||||
{
|
||||
use std::path::Path;
|
||||
|
||||
// Determine if we were invoked via the special alias.
|
||||
let argv0 = std::env::args().next().unwrap_or_default();
|
||||
let exe_name = Path::new(&argv0)
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if exe_name == "codex-linux-sandbox" {
|
||||
// Safety: [`run_main`] never returns.
|
||||
crate::run_main();
|
||||
}
|
||||
|
||||
// Regular invocation – create a Tokio runtime and execute the provided
|
||||
// async entry-point.
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
runtime.block_on(async move {
|
||||
let codex_linux_sandbox_exe: Option<PathBuf> = if cfg!(target_os = "linux") {
|
||||
std::env::current_exe().ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
main_fn(codex_linux_sandbox_exe).await
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
//! program. The utility connects, issues a `tools/list` request and prints the
|
||||
//! server's response as pretty JSON.
|
||||
|
||||
use std::ffi::OsString;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -38,7 +37,7 @@ async fn main() -> Result<()> {
|
||||
.try_init();
|
||||
|
||||
// Collect command-line arguments excluding the program name itself.
|
||||
let mut args: Vec<OsString> = std::env::args_os().skip(1).collect();
|
||||
let mut args: Vec<String> = std::env::args().skip(1).collect();
|
||||
|
||||
if args.is_empty() || args[0] == "--help" || args[0] == "-h" {
|
||||
eprintln!("Usage: mcp-client <program> [args..]\n\nExample: mcp-client codex-mcp-server");
|
||||
@@ -58,12 +57,10 @@ async fn main() -> Result<()> {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".to_string()),
|
||||
},
|
||||
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
//! issue requests and receive strongly-typed results.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
@@ -83,8 +82,8 @@ impl McpClient {
|
||||
/// Caller is responsible for sending the `initialize` request. See
|
||||
/// [`initialize`](Self::initialize) for details.
|
||||
pub async fn new_stdio_client(
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
program: String,
|
||||
args: Vec<String>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
) -> std::io::Result<Self> {
|
||||
let mut child = Command::new(program)
|
||||
|
||||
@@ -16,13 +16,12 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
codex-core = { path = "../core" }
|
||||
codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
schemars = "0.8.22"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
toml = "0.9"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
||||
@@ -33,12 +32,6 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
mcp_test_support = { path = "tests/common" }
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -7,16 +7,15 @@ use mcp_types::ToolInputSchema;
|
||||
use schemars::JsonSchema;
|
||||
use schemars::r#gen::SchemaSettings;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::json_to_toml::json_to_toml;
|
||||
|
||||
/// Client-supplied configuration for a `codex` tool-call.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct CodexToolCallParam {
|
||||
pub(crate) struct CodexToolCallParam {
|
||||
/// The *initial user prompt* to start the Codex conversation.
|
||||
pub prompt: String,
|
||||
|
||||
@@ -46,17 +45,13 @@ pub struct CodexToolCallParam {
|
||||
/// CODEX_HOME/config.toml.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub config: Option<HashMap<String, serde_json::Value>>,
|
||||
|
||||
/// The set of instructions to use instead of the default ones.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub base_instructions: Option<String>,
|
||||
}
|
||||
|
||||
/// Custom enum mirroring [`AskForApproval`], but has an extra dependency on
|
||||
/// [`JsonSchema`].
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum CodexToolCallApprovalPolicy {
|
||||
pub(crate) enum CodexToolCallApprovalPolicy {
|
||||
Untrusted,
|
||||
OnFailure,
|
||||
Never,
|
||||
@@ -74,9 +69,9 @@ impl From<CodexToolCallApprovalPolicy> for AskForApproval {
|
||||
|
||||
/// Custom enum mirroring [`SandboxMode`] from config_types.rs, but with
|
||||
/// `JsonSchema` support.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum CodexToolCallSandboxMode {
|
||||
pub(crate) enum CodexToolCallSandboxMode {
|
||||
ReadOnly,
|
||||
WorkspaceWrite,
|
||||
DangerFullAccess,
|
||||
@@ -113,10 +108,7 @@ pub(crate) fn create_tool_for_codex_tool_call_param() -> Tool {
|
||||
|
||||
Tool {
|
||||
name: "codex".to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
input_schema: tool_input_schema,
|
||||
// TODO(mbolin): This should be defined.
|
||||
output_schema: None,
|
||||
description: Some(
|
||||
"Run a Codex session. Accepts configuration parameters matching the Codex Config struct.".to_string(),
|
||||
),
|
||||
@@ -139,7 +131,6 @@ impl CodexToolCallParam {
|
||||
approval_policy,
|
||||
sandbox,
|
||||
config: cli_overrides,
|
||||
base_instructions,
|
||||
} = self;
|
||||
|
||||
// Build the `ConfigOverrides` recognised by codex-core.
|
||||
@@ -151,7 +142,6 @@ impl CodexToolCallParam {
|
||||
sandbox_mode: sandbox.map(Into::into),
|
||||
model_provider: None,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
};
|
||||
|
||||
let cli_overrides = cli_overrides
|
||||
@@ -166,47 +156,6 @@ impl CodexToolCallParam {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CodexToolCallReplyParam {
|
||||
/// The *session id* for this conversation.
|
||||
pub session_id: String,
|
||||
|
||||
/// The *next user prompt* to continue the Codex conversation.
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
/// Builds a `Tool` definition for the `codex-reply` tool-call.
|
||||
pub(crate) fn create_tool_for_codex_tool_call_reply_param() -> Tool {
|
||||
let schema = SchemaSettings::draft2019_09()
|
||||
.with(|s| {
|
||||
s.inline_subschemas = true;
|
||||
s.option_add_null_type = false;
|
||||
})
|
||||
.into_generator()
|
||||
.into_root_schema_for::<CodexToolCallReplyParam>();
|
||||
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema_value =
|
||||
serde_json::to_value(&schema).expect("Codex reply tool schema should serialise to JSON");
|
||||
|
||||
let tool_input_schema =
|
||||
serde_json::from_value::<ToolInputSchema>(schema_value).unwrap_or_else(|e| {
|
||||
panic!("failed to create Tool from schema: {e}");
|
||||
});
|
||||
|
||||
Tool {
|
||||
name: "codex-reply".to_string(),
|
||||
title: Some("Codex Reply".to_string()),
|
||||
input_schema: tool_input_schema,
|
||||
output_schema: None,
|
||||
description: Some(
|
||||
"Continue a Codex session by providing the session id and prompt.".to_string(),
|
||||
),
|
||||
annotations: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -230,7 +179,6 @@ mod tests {
|
||||
let tool_json = serde_json::to_value(&tool).expect("tool serializes");
|
||||
let expected_tool_json = serde_json::json!({
|
||||
"name": "codex",
|
||||
"title": "Codex",
|
||||
"description": "Run a Codex session. Accepts configuration parameters matching the Codex Config struct.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
@@ -274,10 +222,6 @@ mod tests {
|
||||
"description": "The *initial user prompt* to start the Codex conversation.",
|
||||
"type": "string"
|
||||
},
|
||||
"base-instructions": {
|
||||
"description": "The set of instructions to use instead of the default ones.",
|
||||
"type": "string"
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"prompt"
|
||||
@@ -286,34 +230,4 @@ mod tests {
|
||||
});
|
||||
assert_eq!(expected_tool_json, tool_json);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_codex_tool_reply_json_schema() {
|
||||
let tool = create_tool_for_codex_tool_call_reply_param();
|
||||
#[expect(clippy::expect_used)]
|
||||
let tool_json = serde_json::to_value(&tool).expect("tool serializes");
|
||||
let expected_tool_json = serde_json::json!({
|
||||
"description": "Continue a Codex session by providing the session id and prompt.",
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"description": "The *next user prompt* to continue the Codex conversation.",
|
||||
"type": "string"
|
||||
},
|
||||
"sessionId": {
|
||||
"description": "The *session id* for this conversation.",
|
||||
"type": "string"
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"prompt",
|
||||
"sessionId",
|
||||
],
|
||||
"type": "object",
|
||||
},
|
||||
"name": "codex-reply",
|
||||
"title": "Codex Reply",
|
||||
});
|
||||
assert_eq!(expected_tool_json, tool_json);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,35 +2,33 @@
|
||||
//! Tokio task. Separated from `message_processor.rs` to keep that file small
|
||||
//! and to make future feature-growth easier to manage.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::CodexConversation;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::CallToolResultContent;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
|
||||
pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602;
|
||||
/// Convert a Codex [`Event`] to an MCP notification.
|
||||
fn codex_event_to_notification(event: &Event) -> JSONRPCMessage {
|
||||
#[expect(clippy::expect_used)]
|
||||
JSONRPCMessage::Notification(mcp_types::JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: "codex/event".into(),
|
||||
params: Some(serde_json::to_value(event).expect("Event must serialize")),
|
||||
})
|
||||
}
|
||||
|
||||
/// Run a complete Codex session and stream events back to the client.
|
||||
///
|
||||
@@ -40,43 +38,33 @@ pub async fn run_codex_tool_session(
|
||||
id: RequestId,
|
||||
initial_prompt: String,
|
||||
config: CodexConfig,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
outgoing: Sender<JSONRPCMessage>,
|
||||
) {
|
||||
let CodexConversation {
|
||||
codex,
|
||||
session_configured,
|
||||
session_id,
|
||||
..
|
||||
} = match init_codex(config).await {
|
||||
let (codex, first_event, _ctrl_c) = match init_codex(config).await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Failed to start Codex session: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing.send_response(id.clone(), result.into()).await;
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let codex = Arc::new(codex);
|
||||
|
||||
// update the session map so we can retrieve the session in a reply, and then drop it, since
|
||||
// we no longer need it for this function
|
||||
session_map.lock().await.insert(session_id, codex.clone());
|
||||
drop(session_map);
|
||||
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&session_configured,
|
||||
Some(OutgoingNotificationMeta::new(Some(id.clone()))),
|
||||
)
|
||||
// Send initial SessionConfigured event.
|
||||
let _ = outgoing
|
||||
.send(codex_event_to_notification(&first_event))
|
||||
.await;
|
||||
|
||||
// Use the original MCP request ID as the `sub_id` for the Codex submission so that
|
||||
@@ -86,12 +74,9 @@ pub async fn run_codex_tool_session(
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.insert(id.clone(), session_id);
|
||||
|
||||
let submission = Submission {
|
||||
id: sub_id.clone(),
|
||||
id: sub_id,
|
||||
op: Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: initial_prompt.clone(),
|
||||
@@ -101,143 +86,86 @@ pub async fn run_codex_tool_session(
|
||||
|
||||
if let Err(e) = codex.submit_with_id(submission).await {
|
||||
tracing::error!("Failed to submit initial prompt: {e}");
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid.lock().await.remove(&id);
|
||||
return;
|
||||
}
|
||||
|
||||
run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await;
|
||||
}
|
||||
|
||||
pub async fn run_codex_tool_session_reply(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
prompt: String,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
session_id: Uuid,
|
||||
) {
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id.clone(), session_id);
|
||||
if let Err(e) = codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text { text: prompt }],
|
||||
})
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to submit user input: {e}");
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id);
|
||||
return;
|
||||
}
|
||||
|
||||
run_codex_tool_session_inner(
|
||||
codex,
|
||||
outgoing,
|
||||
request_id,
|
||||
running_requests_id_to_codex_uuid,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn run_codex_tool_session_inner(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&event,
|
||||
Some(OutgoingNotificationMeta::new(Some(request_id.clone()))),
|
||||
)
|
||||
.await;
|
||||
let _ = outgoing.send(codex_event_to_notification(&event)).await;
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
match &event.msg {
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
last_agent_message = Some(message.clone());
|
||||
}
|
||||
EventMsg::Error(err_event) => {
|
||||
// Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption).
|
||||
let result = json!({
|
||||
"error": err_event.message,
|
||||
});
|
||||
outgoing.send_response(request_id.clone(), result).await;
|
||||
break;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
let text = match last_agent_message {
|
||||
Some(msg) => msg.clone(),
|
||||
None => "".to_string(),
|
||||
};
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text,
|
||||
text: "EXEC_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing
|
||||
.send_response(request_id.clone(), result.into())
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "PATCH_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
};
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: _,
|
||||
}) => {
|
||||
let result = if let Some(msg) = last_agent_message {
|
||||
CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: msg,
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
}
|
||||
} else {
|
||||
CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: String::new(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
}
|
||||
};
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
// unregister the id so we don't keep it in the map
|
||||
running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id);
|
||||
break;
|
||||
}
|
||||
EventMsg::SessionConfigured(_) => {
|
||||
@@ -249,10 +177,8 @@ async fn run_codex_tool_session_inner(
|
||||
EventMsg::AgentReasoningDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::TaskStarted
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
@@ -262,8 +188,7 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
| EventMsg::GetHistoryEntryResponse(_) => {
|
||||
// For now, we do not do anything extra for these
|
||||
// events. Note that
|
||||
// send(codex_event_to_notification(&event)) above has
|
||||
@@ -275,18 +200,19 @@ async fn run_codex_tool_session_inner(
|
||||
}
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Codex runtime error: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
// TODO(mbolin): Could present the error in a more
|
||||
// structured way.
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing
|
||||
.send_response(request_id.clone(), result.into())
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
|
||||
/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the
|
||||
/// `params` field of an [`ElicitRequest`].
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecApprovalElicitRequestParams {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
pub message: String,
|
||||
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
|
||||
// These are additional fields the client can use to
|
||||
// correlate the request with the codex tool call.
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
pub codex_call_id: String,
|
||||
pub codex_command: Vec<String>,
|
||||
pub codex_cwd: PathBuf,
|
||||
}
|
||||
|
||||
// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See:
|
||||
// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636
|
||||
// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages
|
||||
// It should have "action" and "content" fields.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ExecApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_exec_approval_request(
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
outgoing: Arc<crate::outgoing_message::OutgoingMessageSender>,
|
||||
codex: Arc<Codex>,
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
call_id: String,
|
||||
) {
|
||||
let escaped_command =
|
||||
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "));
|
||||
let message = format!(
|
||||
"Allow Codex to run `{escaped_command}` in `{cwd}`?",
|
||||
cwd = cwd.to_string_lossy()
|
||||
);
|
||||
|
||||
let params = ExecApprovalElicitRequestParams {
|
||||
message,
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_call_id: call_id,
|
||||
codex_command: command,
|
||||
codex_cwd: cwd,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!("Failed to serialize ExecApprovalElicitRequestParams: {err}");
|
||||
error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we don't block the main agent loop.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event_id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_exec_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to deserialize `value` and then make the appropriate call to `codex`.
|
||||
let response = serde_json::from_value::<ExecApprovalResponse>(value).unwrap_or_else(|err| {
|
||||
error!("failed to deserialize ExecApprovalResponse: {err}");
|
||||
// If we cannot deserialize the response, we deny the request to be
|
||||
// conservative.
|
||||
ExecApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit ExecApproval: {err}");
|
||||
}
|
||||
}
|
||||
@@ -13,26 +13,13 @@ use tokio::sync::mpsc;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod exec_approval;
|
||||
mod json_to_toml;
|
||||
mod message_processor;
|
||||
mod outgoing_message;
|
||||
mod patch_approval;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
pub use crate::codex_tool_config::CodexToolCallParam;
|
||||
pub use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
|
||||
pub use crate::exec_approval::ExecApprovalResponse;
|
||||
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
|
||||
pub use crate::patch_approval::PatchApprovalResponse;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage – 128 messages should be
|
||||
@@ -44,12 +31,11 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
// control the log level with `RUST_LOG`.
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
// Set up channels.
|
||||
let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
|
||||
// Task: read from stdin, push to `incoming_tx`.
|
||||
let stdin_reader_handle = tokio::spawn({
|
||||
@@ -77,15 +63,16 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
|
||||
// Task: process incoming messages.
|
||||
let processor_handle = tokio::spawn({
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe);
|
||||
let mut processor = MessageProcessor::new(outgoing_tx.clone(), codex_linux_sandbox_exe);
|
||||
async move {
|
||||
while let Some(msg) = incoming_rx.recv().await {
|
||||
match msg {
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r),
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r),
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n),
|
||||
JSONRPCMessage::BatchRequest(b) => processor.process_batch_request(b),
|
||||
JSONRPCMessage::Error(e) => processor.process_error(e),
|
||||
JSONRPCMessage::BatchResponse(b) => processor.process_batch_response(b),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,8 +83,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
// Task: write outgoing messages to stdout.
|
||||
let stdout_writer_handle = tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(outgoing_message) = outgoing_rx.recv().await {
|
||||
let msg: JSONRPCMessage = outgoing_message.into();
|
||||
while let Some(msg) = outgoing_rx.recv().await {
|
||||
match serde_json::to_string(&msg) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = stdout.write_all(json.as_bytes()).await {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_mcp_server::run_main;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move {
|
||||
run_main(codex_linux_sandbox_exe).await?;
|
||||
Ok(())
|
||||
})
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::codex_tool_config::CodexToolCallParam;
|
||||
use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::Submission;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::CallToolResultContent;
|
||||
use mcp_types::ClientRequest;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCBatchRequest;
|
||||
use mcp_types::JSONRPCBatchResponse;
|
||||
use mcp_types::JSONRPCError;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
@@ -27,35 +24,30 @@ use mcp_types::ServerCapabilitiesTools;
|
||||
use mcp_types::ServerNotification;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) struct MessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: mpsc::Sender<JSONRPCMessage>,
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
/// Create a new `MessageProcessor`, retaining a handle to the outgoing
|
||||
/// `Sender` so handlers can enqueue messages to be written to stdout.
|
||||
pub(crate) fn new(
|
||||
outgoing: OutgoingMessageSender,
|
||||
outgoing: mpsc::Sender<JSONRPCMessage>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outgoing: Arc::new(outgoing),
|
||||
outgoing,
|
||||
initialized: false,
|
||||
codex_linux_sandbox_exe,
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
pub(crate) fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
|
||||
@@ -70,10 +62,10 @@ impl MessageProcessor {
|
||||
// Dispatch to a dedicated handler for each request type.
|
||||
match client_request {
|
||||
ClientRequest::InitializeRequest(params) => {
|
||||
self.handle_initialize(request_id, params).await;
|
||||
self.handle_initialize(request_id, params);
|
||||
}
|
||||
ClientRequest::PingRequest(params) => {
|
||||
self.handle_ping(request_id, params).await;
|
||||
self.handle_ping(request_id, params);
|
||||
}
|
||||
ClientRequest::ListResourcesRequest(params) => {
|
||||
self.handle_list_resources(params);
|
||||
@@ -97,10 +89,10 @@ impl MessageProcessor {
|
||||
self.handle_get_prompt(params);
|
||||
}
|
||||
ClientRequest::ListToolsRequest(params) => {
|
||||
self.handle_list_tools(request_id, params).await;
|
||||
self.handle_list_tools(request_id, params);
|
||||
}
|
||||
ClientRequest::CallToolRequest(params) => {
|
||||
self.handle_call_tool(request_id, params).await;
|
||||
self.handle_call_tool(request_id, params);
|
||||
}
|
||||
ClientRequest::SetLevelRequest(params) => {
|
||||
self.handle_set_level(params);
|
||||
@@ -112,14 +104,12 @@ impl MessageProcessor {
|
||||
}
|
||||
|
||||
/// Handle a standalone JSON-RPC response originating from the peer.
|
||||
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
pub(crate) fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
tracing::info!("<- response: {:?}", response);
|
||||
let JSONRPCResponse { id, result, .. } = response;
|
||||
self.outgoing.notify_client_response(id, result).await
|
||||
}
|
||||
|
||||
/// Handle a fire-and-forget JSON-RPC notification.
|
||||
pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) {
|
||||
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
|
||||
let server_notification = match ServerNotification::try_from(notification) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
@@ -132,7 +122,7 @@ impl MessageProcessor {
|
||||
// handler so additional logic can be implemented incrementally.
|
||||
match server_notification {
|
||||
ServerNotification::CancelledNotification(params) => {
|
||||
self.handle_cancelled_notification(params).await;
|
||||
self.handle_cancelled_notification(params);
|
||||
}
|
||||
ServerNotification::ProgressNotification(params) => {
|
||||
self.handle_progress_notification(params);
|
||||
@@ -155,12 +145,42 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a batch of requests and/or notifications.
|
||||
pub(crate) fn process_batch_request(&mut self, batch: JSONRPCBatchRequest) {
|
||||
tracing::info!("<- batch request containing {} item(s)", batch.len());
|
||||
for item in batch {
|
||||
match item {
|
||||
mcp_types::JSONRPCBatchRequestItem::JSONRPCRequest(req) => {
|
||||
self.process_request(req);
|
||||
}
|
||||
mcp_types::JSONRPCBatchRequestItem::JSONRPCNotification(note) => {
|
||||
self.process_notification(note);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an error object received from the peer.
|
||||
pub(crate) fn process_error(&mut self, err: JSONRPCError) {
|
||||
tracing::error!("<- error: {:?}", err);
|
||||
}
|
||||
|
||||
async fn handle_initialize(
|
||||
/// Handle a batch of responses/errors.
|
||||
pub(crate) fn process_batch_response(&mut self, batch: JSONRPCBatchResponse) {
|
||||
tracing::info!("<- batch response containing {} item(s)", batch.len());
|
||||
for item in batch {
|
||||
match item {
|
||||
mcp_types::JSONRPCBatchResponseItem::JSONRPCResponse(resp) => {
|
||||
self.process_response(resp);
|
||||
}
|
||||
mcp_types::JSONRPCBatchResponseItem::JSONRPCError(err) => {
|
||||
self.process_error(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_initialize(
|
||||
&mut self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
|
||||
@@ -169,12 +189,19 @@ impl MessageProcessor {
|
||||
|
||||
if self.initialized {
|
||||
// Already initialised: send JSON-RPC error response.
|
||||
let error = JSONRPCErrorError {
|
||||
code: -32600, // Invalid Request
|
||||
message: "initialize called more than once".to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(id, error).await;
|
||||
let error_msg = JSONRPCMessage::Error(JSONRPCError {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
error: JSONRPCErrorError {
|
||||
code: -32600, // Invalid Request
|
||||
message: "initialize called more than once".to_string(),
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
|
||||
if let Err(e) = self.outgoing.try_send(error_msg) {
|
||||
tracing::error!("Failed to send initialization error: {e}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -196,34 +223,38 @@ impl MessageProcessor {
|
||||
protocol_version: params.protocol_version.clone(),
|
||||
server_info: mcp_types::Implementation {
|
||||
name: "codex-mcp-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
version: mcp_types::MCP_SCHEMA_VERSION.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
self.send_response::<mcp_types::InitializeRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::InitializeRequest>(id, result);
|
||||
}
|
||||
|
||||
async fn send_response<T>(&self, id: RequestId, result: T::Result)
|
||||
fn send_response<T>(&self, id: RequestId, result: T::Result)
|
||||
where
|
||||
T: ModelContextProtocolRequest,
|
||||
{
|
||||
// result has `Serialized` instance so should never fail
|
||||
#[expect(clippy::unwrap_used)]
|
||||
let result = serde_json::to_value(result).unwrap();
|
||||
self.outgoing.send_response(id, result).await;
|
||||
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result: serde_json::to_value(result).unwrap(),
|
||||
});
|
||||
|
||||
if let Err(e) = self.outgoing.try_send(response) {
|
||||
tracing::error!("Failed to send response: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ping(
|
||||
fn handle_ping(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::PingRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::info!("ping -> params: {:?}", params);
|
||||
let result = json!({});
|
||||
self.send_response::<mcp_types::PingRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::PingRequest>(id, result);
|
||||
}
|
||||
|
||||
fn handle_list_resources(
|
||||
@@ -276,25 +307,21 @@ impl MessageProcessor {
|
||||
tracing::info!("prompts/get -> params: {:?}", params);
|
||||
}
|
||||
|
||||
async fn handle_list_tools(
|
||||
fn handle_list_tools(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::trace!("tools/list -> {params:?}");
|
||||
let result = ListToolsResult {
|
||||
tools: vec![
|
||||
create_tool_for_codex_tool_call_param(),
|
||||
create_tool_for_codex_tool_call_reply_param(),
|
||||
],
|
||||
tools: vec![create_tool_for_codex_tool_call_param()],
|
||||
next_cursor: None,
|
||||
};
|
||||
|
||||
self.send_response::<mcp_types::ListToolsRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::ListToolsRequest>(id, result);
|
||||
}
|
||||
|
||||
async fn handle_call_tool(
|
||||
fn handle_call_tool(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
@@ -302,36 +329,28 @@ impl MessageProcessor {
|
||||
tracing::info!("tools/call -> params: {:?}", params);
|
||||
let CallToolRequestParams { name, arguments } = params;
|
||||
|
||||
match name.as_str() {
|
||||
"codex" => self.handle_tool_call_codex(id, arguments).await,
|
||||
"codex-reply" => {
|
||||
self.handle_tool_call_codex_session_reply(id, arguments)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Unknown tool '{name}'"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
}
|
||||
// We only support the "codex" tool for now.
|
||||
if name != "codex" {
|
||||
// Tool not found – return error result so the LLM can react.
|
||||
let result = CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Unknown tool '{name}'"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option<serde_json::Value>) {
|
||||
let (initial_prompt, config): (String, CodexConfig) = match arguments {
|
||||
Some(json_val) => match serde_json::from_value::<CodexToolCallParam>(json_val) {
|
||||
Ok(tool_cfg) => match tool_cfg.into_config(self.codex_linux_sandbox_exe.clone()) {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!(
|
||||
"Failed to load Codex configuration from overrides: {e}"
|
||||
@@ -339,31 +358,27 @@ impl MessageProcessor {
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse configuration for Codex tool: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text:
|
||||
"Missing arguments for codex tool-call; the `prompt` field is required."
|
||||
@@ -371,147 +386,21 @@ impl MessageProcessor {
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Clone outgoing and session map to move into async task.
|
||||
// Clone outgoing sender to move into async task.
|
||||
let outgoing = self.outgoing.clone();
|
||||
let session_map = self.session_map.clone();
|
||||
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
// Spawn an async task to handle the Codex session so that we do not
|
||||
// block the synchronous message-processing loop.
|
||||
task::spawn(async move {
|
||||
// Run the Codex session and stream events back to the client.
|
||||
crate::codex_tool_runner::run_codex_tool_session(
|
||||
id,
|
||||
initial_prompt,
|
||||
config,
|
||||
outgoing,
|
||||
session_map,
|
||||
running_requests_id_to_codex_uuid,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
async fn handle_tool_call_codex_session_reply(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
arguments: Option<serde_json::Value>,
|
||||
) {
|
||||
tracing::info!("tools/call -> params: {:?}", arguments);
|
||||
|
||||
// parse arguments
|
||||
let CodexToolCallReplyParam { session_id, prompt } = match arguments {
|
||||
Some(json_val) => match serde_json::from_value::<CodexToolCallReplyParam>(json_val) {
|
||||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse Codex tool call reply parameters: {e}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse configuration for Codex tool: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
tracing::error!(
|
||||
"Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required."
|
||||
);
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required.".to_owned(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let session_id = match Uuid::parse_str(&session_id) {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse session_id: {e}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse session_id: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// load codex from session map
|
||||
let session_map_mutex = Arc::clone(&self.session_map);
|
||||
|
||||
// Clone outgoing and session map to move into async task.
|
||||
let outgoing = self.outgoing.clone();
|
||||
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
let codex = {
|
||||
let session_map = session_map_mutex.lock().await;
|
||||
match session_map.get(&session_id).cloned() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Session not found for session_id: {session_id}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing
|
||||
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Spawn the long-running reply handler.
|
||||
tokio::spawn({
|
||||
let codex = codex.clone();
|
||||
let outgoing = outgoing.clone();
|
||||
let prompt = prompt.clone();
|
||||
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
|
||||
|
||||
async move {
|
||||
crate::codex_tool_runner::run_codex_tool_session_reply(
|
||||
codex,
|
||||
outgoing,
|
||||
request_id,
|
||||
prompt,
|
||||
running_requests_id_to_codex_uuid,
|
||||
session_id,
|
||||
)
|
||||
crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing)
|
||||
.await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -533,58 +422,11 @@ impl MessageProcessor {
|
||||
// Notification handlers
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
async fn handle_cancelled_notification(
|
||||
fn handle_cancelled_notification(
|
||||
&self,
|
||||
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
|
||||
) {
|
||||
let request_id = params.request_id;
|
||||
// Create a stable string form early for logging and submission id.
|
||||
let request_id_string = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
};
|
||||
|
||||
// Obtain the session_id while holding the first lock, then release.
|
||||
let session_id = {
|
||||
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
|
||||
match map_guard.get(&request_id) {
|
||||
Some(id) => *id, // Uuid is Copy
|
||||
None => {
|
||||
tracing::warn!("Session not found for request_id: {}", request_id_string);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
tracing::info!("session_id: {session_id}");
|
||||
|
||||
// Obtain the Codex Arc while holding the session_map lock, then release.
|
||||
let codex_arc = {
|
||||
let sessions_guard = self.session_map.lock().await;
|
||||
match sessions_guard.get(&session_id) {
|
||||
Some(codex) => Arc::clone(codex),
|
||||
None => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Submit interrupt to Codex.
|
||||
let err = codex_arc
|
||||
.submit_with_id(Submission {
|
||||
id: request_id_string,
|
||||
op: codex_core::protocol::Op::Interrupt,
|
||||
})
|
||||
.await;
|
||||
if let Err(e) = err {
|
||||
tracing::error!("Failed to submit interrupt to Codex: {e}");
|
||||
return;
|
||||
}
|
||||
// unregister the id so we don't keep it in the map
|
||||
self.running_requests_id_to_codex_uuid
|
||||
.lock()
|
||||
.await
|
||||
.remove(&request_id);
|
||||
tracing::info!("notifications/cancelled -> params: {:?}", params);
|
||||
}
|
||||
|
||||
fn handle_progress_notification(
|
||||
|
||||
@@ -1,331 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_core::protocol::Event;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCError;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::Result;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::warn;
|
||||
|
||||
/// Sends messages to the client and manages request callbacks.
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingMessage>,
|
||||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
|
||||
Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
sender,
|
||||
request_id_to_callback: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> oneshot::Receiver<Result> {
|
||||
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let outgoing_message_id = id.clone();
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
{
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.insert(id, tx_approve);
|
||||
}
|
||||
|
||||
let outgoing_message = OutgoingMessage::Request(OutgoingRequest {
|
||||
id: outgoing_message_id,
|
||||
method: method.to_string(),
|
||||
params,
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
rx_approve
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) {
|
||||
let entry = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove_entry(&id)
|
||||
};
|
||||
|
||||
match entry {
|
||||
Some((id, sender)) => {
|
||||
if let Err(err) = sender.send(result) {
|
||||
warn!("could not notify callback for {id:?} due to: {err:?}");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("could not find callback for {id:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response(&self, id: RequestId, result: Result) {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_event_as_notification(
|
||||
&self,
|
||||
event: &Event,
|
||||
meta: Option<OutgoingNotificationMeta>,
|
||||
) {
|
||||
#[allow(clippy::expect_used)]
|
||||
let event_json = serde_json::to_value(event).expect("Event must serialize");
|
||||
|
||||
let params = if let Ok(params) = serde_json::to_value(OutgoingNotificationParams {
|
||||
meta,
|
||||
event: event_json.clone(),
|
||||
}) {
|
||||
params
|
||||
} else {
|
||||
warn!("Failed to serialize event as OutgoingNotificationParams");
|
||||
event_json
|
||||
};
|
||||
|
||||
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: "codex/event".to_string(),
|
||||
params: Some(params.clone()),
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
|
||||
self.send_event_as_notification_new_schema(event, Some(params.clone()))
|
||||
.await;
|
||||
}
|
||||
|
||||
// should be backwards compatible.
|
||||
// it will replace send_event_as_notification eventually.
|
||||
async fn send_event_as_notification_new_schema(
|
||||
&self,
|
||||
event: &Event,
|
||||
params: Option<serde_json::Value>,
|
||||
) {
|
||||
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: event.msg.to_string(),
|
||||
params,
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
pub(crate) enum OutgoingMessage {
|
||||
Request(OutgoingRequest),
|
||||
Notification(OutgoingNotification),
|
||||
Response(OutgoingResponse),
|
||||
Error(OutgoingError),
|
||||
}
|
||||
|
||||
impl From<OutgoingMessage> for JSONRPCMessage {
|
||||
fn from(val: OutgoingMessage) -> Self {
|
||||
use OutgoingMessage::*;
|
||||
match val {
|
||||
Request(OutgoingRequest { id, method, params }) => {
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Notification(OutgoingNotification { method, params }) => {
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Response(OutgoingResponse { id, result }) => {
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result,
|
||||
})
|
||||
}
|
||||
Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
error,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingRequest {
|
||||
pub id: RequestId,
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingNotification {
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingNotificationParams {
|
||||
#[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<OutgoingNotificationMeta>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub event: serde_json::Value,
|
||||
}
|
||||
|
||||
// Additional mcp-specific data to be added to a [`codex_core::protocol::Event`] as notification.params._meta
|
||||
// MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic#meta
|
||||
// Typescript Schema: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/0695a497eb50a804fc0e88c18a93a21a675d6b3e/schema/2025-06-18/schema.ts
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub(crate) struct OutgoingNotificationMeta {
|
||||
pub request_id: Option<RequestId>,
|
||||
}
|
||||
|
||||
impl OutgoingNotificationMeta {
|
||||
pub(crate) fn new(request_id: Option<RequestId>) -> Self {
|
||||
Self { request_id }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingResponse {
|
||||
pub id: RequestId,
|
||||
pub result: Result,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingError {
|
||||
pub error: JSONRPCErrorError,
|
||||
pub id: RequestId,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_send_event_as_notification() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(2);
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
|
||||
session_id: Uuid::new_v4(),
|
||||
model: "gpt-4o".to_string(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
}),
|
||||
};
|
||||
|
||||
outgoing_message_sender
|
||||
.send_event_as_notification(&event, None)
|
||||
.await;
|
||||
|
||||
let result = outgoing_rx.recv().await.unwrap();
|
||||
let OutgoingMessage::Notification(OutgoingNotification { method, params }) = result else {
|
||||
panic!("expected Notification for first message");
|
||||
};
|
||||
assert_eq!(method, "codex/event");
|
||||
|
||||
let Ok(expected_params) = serde_json::to_value(&event) else {
|
||||
panic!("Event must serialize");
|
||||
};
|
||||
assert_eq!(params, Some(expected_params.clone()));
|
||||
|
||||
let result2 = outgoing_rx.recv().await.unwrap();
|
||||
let OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: method2,
|
||||
params: params2,
|
||||
}) = result2
|
||||
else {
|
||||
panic!("expected Notification for second message");
|
||||
};
|
||||
assert_eq!(method2, event.msg.to_string());
|
||||
assert_eq!(params2, Some(expected_params));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_send_event_as_notification_with_meta() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(2);
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
|
||||
let session_configured_event = SessionConfiguredEvent {
|
||||
session_id: Uuid::new_v4(),
|
||||
model: "gpt-4o".to_string(),
|
||||
history_log_id: 1,
|
||||
history_entry_count: 1000,
|
||||
};
|
||||
let event = Event {
|
||||
id: "1".to_string(),
|
||||
msg: EventMsg::SessionConfigured(session_configured_event.clone()),
|
||||
};
|
||||
let meta = OutgoingNotificationMeta {
|
||||
request_id: Some(RequestId::String("123".to_string())),
|
||||
};
|
||||
|
||||
outgoing_message_sender
|
||||
.send_event_as_notification(&event, Some(meta))
|
||||
.await;
|
||||
|
||||
let result = outgoing_rx.recv().await.unwrap();
|
||||
let OutgoingMessage::Notification(OutgoingNotification { method, params }) = result else {
|
||||
panic!("expected Notification for first message");
|
||||
};
|
||||
assert_eq!(method, "codex/event");
|
||||
let expected_params = json!({
|
||||
"_meta": {
|
||||
"requestId": "123",
|
||||
},
|
||||
"id": "1",
|
||||
"msg": {
|
||||
"session_id": session_configured_event.session_id,
|
||||
"model": session_configured_event.model,
|
||||
"history_log_id": session_configured_event.history_log_id,
|
||||
"history_entry_count": session_configured_event.history_entry_count,
|
||||
"type": "session_configured",
|
||||
}
|
||||
});
|
||||
assert_eq!(params.unwrap(), expected_params);
|
||||
|
||||
let result2 = outgoing_rx.recv().await.unwrap();
|
||||
let OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: method2,
|
||||
params: params2,
|
||||
}) = result2
|
||||
else {
|
||||
panic!("expected Notification for second message");
|
||||
};
|
||||
assert_eq!(method2, event.msg.to_string());
|
||||
assert_eq!(params2.unwrap(), expected_params);
|
||||
}
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PatchApprovalElicitRequestParams {
|
||||
pub message: String,
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
pub codex_call_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_grant_root: Option<PathBuf>,
|
||||
pub codex_changes: HashMap<PathBuf, FileChange>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct PatchApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_patch_approval_request(
|
||||
call_id: String,
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex: Arc<Codex>,
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
let mut message_lines = Vec::new();
|
||||
if let Some(r) = &reason {
|
||||
message_lines.push(r.clone());
|
||||
}
|
||||
message_lines.push("Allow Codex to apply proposed code changes?".to_string());
|
||||
|
||||
let params = PatchApprovalElicitRequestParams {
|
||||
message: message_lines.join("\n"),
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_call_id: call_id,
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!("Failed to serialize PatchApprovalElicitRequestParams: {err}");
|
||||
error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we don't block the main agent loop.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event_id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_patch_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn on_patch_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
if let Err(submit_err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id.clone(),
|
||||
decision: ReviewDecision::Denied,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit denied PatchApproval after request failure: {submit_err}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let response = serde_json::from_value::<PatchApprovalResponse>(value).unwrap_or_else(|err| {
|
||||
error!("failed to deserialize PatchApprovalResponse: {err}");
|
||||
PatchApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit PatchApproval: {err}");
|
||||
}
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::ExecApprovalElicitRequestParams;
|
||||
use codex_mcp_server::ExecApprovalResponse;
|
||||
use codex_mcp_server::PatchApprovalElicitRequestParams;
|
||||
use codex_mcp_server::PatchApprovalResponse;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::MockServer;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_apply_patch_sse_response;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_test_support::create_shell_sse_response;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
/// Test that a shell command that is not on the "trusted" list triggers an
|
||||
/// elicitation request to the MCP and that sending the approval runs the
|
||||
/// command, as expected.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shell_command_approval_triggers_elicitation() {
|
||||
if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Apparently `#[tokio::test]` must return `()`, so we create a helper
|
||||
// function that returns `Result` so we can use `?` in favor of `unwrap`.
|
||||
if let Err(err) = shell_command_approval_triggers_elicitation().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
// We use `git init` because it will not be on the "trusted" list.
|
||||
let shell_command = vec!["git".to_string(), "init".to_string()];
|
||||
let workdir_for_shell_function_call = TempDir::new()?;
|
||||
|
||||
let McpHandle {
|
||||
process: mut mcp_process,
|
||||
server: _server,
|
||||
dir: _dir,
|
||||
} = create_mcp_process(vec![
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
Some(5_000),
|
||||
"call1234",
|
||||
)?,
|
||||
create_final_assistant_message_sse_response("Enjoy your new git repo!")?,
|
||||
])
|
||||
.await?;
|
||||
|
||||
// Send a "codex" tool request, which should hit the completions endpoint.
|
||||
// In turn, it should reply with a tool call, which the MCP should forward
|
||||
// as an elicitation.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "run `git init`".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let elicitation_request = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_request_message(),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// This is the first request from the server, so the id should be 0 given
|
||||
// how things are currently implemented.
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
let expected_elicitation_request = create_expected_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
shell_command.clone(),
|
||||
workdir_for_shell_function_call.path(),
|
||||
codex_request_id.to_string(),
|
||||
// Internal Codex id: empirically it is 1, but this is
|
||||
// admittedly an internal detail that could change.
|
||||
"1".to_string(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
// Accept the `git init` request by responding to the elicitation.
|
||||
mcp_process
|
||||
.send_response(
|
||||
elicitation_request_id,
|
||||
serde_json::to_value(ExecApprovalResponse {
|
||||
decision: ReviewDecision::Approved,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Verify the original `codex` tool call completes and that `git init` ran
|
||||
// successfully.
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Enjoy your new git repo!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
assert!(
|
||||
workdir_for_shell_function_call.path().join(".git").is_dir(),
|
||||
".git folder should have been created"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_expected_elicitation_request(
|
||||
elicitation_request_id: RequestId,
|
||||
command: Vec<String>,
|
||||
workdir: &Path,
|
||||
codex_mcp_tool_call_id: String,
|
||||
codex_event_id: String,
|
||||
) -> anyhow::Result<JSONRPCRequest> {
|
||||
let expected_message = format!(
|
||||
"Allow Codex to run `{}` in `{}`?",
|
||||
shlex::try_join(command.iter().map(|s| s.as_ref()))?,
|
||||
workdir.to_string_lossy()
|
||||
);
|
||||
Ok(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: elicitation_request_id,
|
||||
method: ElicitRequest::METHOD.to_string(),
|
||||
params: Some(serde_json::to_value(&ExecApprovalElicitRequestParams {
|
||||
message: expected_message,
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id,
|
||||
codex_event_id,
|
||||
codex_command: command,
|
||||
codex_cwd: workdir.to_path_buf(),
|
||||
codex_call_id: "call1234".to_string(),
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
/// Test that patch approval triggers an elicitation request to the MCP and that
|
||||
/// sending the approval applies the patch, as expected.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_patch_approval_triggers_elicitation() {
|
||||
if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = patch_approval_triggers_elicitation().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
let cwd = TempDir::new()?;
|
||||
let test_file = cwd.path().join("destination_file.txt");
|
||||
std::fs::write(&test_file, "original content\n")?;
|
||||
|
||||
let patch_content = format!(
|
||||
"*** Begin Patch\n*** Update File: {}\n-original content\n+modified content\n*** End Patch",
|
||||
test_file.as_path().to_string_lossy()
|
||||
);
|
||||
|
||||
let McpHandle {
|
||||
process: mut mcp_process,
|
||||
server: _server,
|
||||
dir: _dir,
|
||||
} = create_mcp_process(vec![
|
||||
create_apply_patch_sse_response(&patch_content, "call1234")?,
|
||||
create_final_assistant_message_sse_response("Patch has been applied successfully!")?,
|
||||
])
|
||||
.await?;
|
||||
|
||||
// Send a "codex" tool request that will trigger the apply_patch command
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
cwd: Some(cwd.path().to_string_lossy().to_string()),
|
||||
prompt: "please modify the test file".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let elicitation_request = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_request_message(),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
|
||||
let mut expected_changes = HashMap::new();
|
||||
expected_changes.insert(
|
||||
test_file.as_path().to_path_buf(),
|
||||
FileChange::Update {
|
||||
unified_diff: "@@ -1 +1 @@\n-original content\n+modified content\n".to_string(),
|
||||
move_path: None,
|
||||
},
|
||||
);
|
||||
|
||||
let expected_elicitation_request = create_expected_patch_approval_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
expected_changes,
|
||||
None, // No grant_root expected
|
||||
None, // No reason expected
|
||||
codex_request_id.to_string(),
|
||||
"1".to_string(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
// Accept the patch approval request by responding to the elicitation
|
||||
mcp_process
|
||||
.send_response(
|
||||
elicitation_request_id,
|
||||
serde_json::to_value(PatchApprovalResponse {
|
||||
decision: ReviewDecision::Approved,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Verify the original `codex` tool call completes
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Patch has been applied successfully!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
let file_contents = std::fs::read_to_string(test_file.as_path())?;
|
||||
assert_eq!(file_contents, "modified content\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_codex_tool_passes_base_instructions() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Apparently `#[tokio::test]` must return `()`, so we create a helper
|
||||
// function that returns `Result` so we can use `?` in favor of `unwrap`.
|
||||
if let Err(err) = codex_tool_passes_base_instructions().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn codex_tool_passes_base_instructions() -> anyhow::Result<()> {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
let server =
|
||||
create_mock_chat_completions_server(vec![create_final_assistant_message_sse_response(
|
||||
"Enjoy!",
|
||||
)?])
|
||||
.await;
|
||||
|
||||
// Run `codex mcp` with a specific config.toml.
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
|
||||
// Send a "codex" tool request, which should hit the completions endpoint.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "How are you?".to_string(),
|
||||
base_instructions: Some("You are a helpful assistant.".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Enjoy!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
let request = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let instructions = request["messages"][0]["content"].as_str().unwrap();
|
||||
assert!(instructions.starts_with("You are a helpful assistant."));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_expected_patch_approval_elicitation_request(
|
||||
elicitation_request_id: RequestId,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
grant_root: Option<PathBuf>,
|
||||
reason: Option<String>,
|
||||
codex_mcp_tool_call_id: String,
|
||||
codex_event_id: String,
|
||||
) -> anyhow::Result<JSONRPCRequest> {
|
||||
let mut message_lines = Vec::new();
|
||||
if let Some(r) = &reason {
|
||||
message_lines.push(r.clone());
|
||||
}
|
||||
message_lines.push("Allow Codex to apply proposed code changes?".to_string());
|
||||
|
||||
Ok(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: elicitation_request_id,
|
||||
method: ElicitRequest::METHOD.to_string(),
|
||||
params: Some(serde_json::to_value(&PatchApprovalElicitRequestParams {
|
||||
message: message_lines.join("\n"),
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id,
|
||||
codex_event_id,
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
codex_call_id: "call1234".to_string(),
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
/// This handle is used to ensure that the MockServer and TempDir are not dropped while
|
||||
/// the McpProcess is still running.
|
||||
pub struct McpHandle {
|
||||
pub process: McpProcess,
|
||||
/// Retain the server for the lifetime of the McpProcess.
|
||||
#[allow(dead_code)]
|
||||
server: MockServer,
|
||||
/// Retain the temporary directory for the lifetime of the McpProcess.
|
||||
#[allow(dead_code)]
|
||||
dir: TempDir,
|
||||
}
|
||||
|
||||
async fn create_mcp_process(responses: Vec<String>) -> anyhow::Result<McpHandle> {
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
Ok(McpHandle {
|
||||
process: mcp_process,
|
||||
server,
|
||||
dir: codex_home,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Codex config that uses the mock server as the model provider.
|
||||
/// It also uses `approval_policy = "untrusted"` so that we exercise the
|
||||
/// elicitation code path for shell commands.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "untrusted"
|
||||
sandbox_policy = "read-only"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
[package]
|
||||
name = "mcp_test_support"
|
||||
version = { workspace = true }
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
assert_cmd = "2"
|
||||
codex-mcp-server = { path = "../.." }
|
||||
mcp-types = { path = "../../../mcp-types" }
|
||||
pretty_assertions = "1.4.1"
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
tempfile = "3"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
wiremock = "0.6"
|
||||
@@ -1,9 +0,0 @@
|
||||
mod mcp_process;
|
||||
mod mock_model_server;
|
||||
mod responses;
|
||||
|
||||
pub use mcp_process::McpProcess;
|
||||
pub use mock_model_server::create_mock_chat_completions_server;
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
pub use responses::create_final_assistant_message_sse_response;
|
||||
pub use responses::create_shell_sse_response;
|
||||
@@ -1,343 +0,0 @@
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::ChildStdin;
|
||||
use tokio::process::ChildStdout;
|
||||
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::InitializeRequestParams;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::ModelContextProtocolNotification;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::process::Command as StdCommand;
|
||||
use tokio::process::Command;
|
||||
|
||||
pub struct McpProcess {
|
||||
next_request_id: AtomicI64,
|
||||
/// Retain this child process until the client is dropped. The Tokio runtime
|
||||
/// will make a "best effort" to reap the process after it exits, but it is
|
||||
/// not a guarantee. See the `kill_on_drop` documentation for details.
|
||||
#[allow(dead_code)]
|
||||
process: Child,
|
||||
stdin: ChildStdin,
|
||||
stdout: BufReader<ChildStdout>,
|
||||
}
|
||||
|
||||
impl McpProcess {
|
||||
pub async fn new(codex_home: &Path) -> anyhow::Result<Self> {
|
||||
// Use assert_cmd to locate the binary path and then switch to tokio::process::Command
|
||||
let std_cmd = StdCommand::cargo_bin("codex-mcp-server")
|
||||
.context("should find binary for codex-mcp-server")?;
|
||||
|
||||
let program = std_cmd.get_program().to_owned();
|
||||
|
||||
let mut cmd = Command::new(program);
|
||||
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
cmd.env("RUST_LOG", "debug");
|
||||
|
||||
let mut process = cmd
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.context("codex-mcp-server proc should start")?;
|
||||
let stdin = process
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::format_err!("mcp should have stdin fd"))?;
|
||||
let stdout = process
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::format_err!("mcp should have stdout fd"))?;
|
||||
let stdout = BufReader::new(stdout);
|
||||
Ok(Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
process,
|
||||
stdin,
|
||||
stdout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Performs the initialization handshake with the MCP server.
|
||||
pub async fn initialize(&mut self) -> anyhow::Result<()> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let params = InitializeRequestParams {
|
||||
capabilities: ClientCapabilities {
|
||||
elicitation: Some(json!({})),
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "elicitation test".into(),
|
||||
title: Some("Elicitation Test".into()),
|
||||
version: "0.0.0".into(),
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(),
|
||||
};
|
||||
let params_value = serde_json::to_value(params)?;
|
||||
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(request_id),
|
||||
method: mcp_types::InitializeRequest::METHOD.into(),
|
||||
params: Some(params_value),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let initialized = self.read_jsonrpc_message().await?;
|
||||
assert_eq!(
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(request_id),
|
||||
result: json!({
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": true
|
||||
},
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "codex-mcp-server",
|
||||
"title": "Codex",
|
||||
"version": "0.0.0"
|
||||
},
|
||||
"protocolVersion": mcp_types::MCP_SCHEMA_VERSION
|
||||
})
|
||||
}),
|
||||
initialized
|
||||
);
|
||||
|
||||
// Send notifications/initialized to ack the response.
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: mcp_types::InitializedNotification::METHOD.into(),
|
||||
params: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the id used to make the request so it can be used when
|
||||
/// correlating notifications.
|
||||
pub async fn send_codex_tool_call(
|
||||
&mut self,
|
||||
params: CodexToolCallParam,
|
||||
) -> anyhow::Result<i64> {
|
||||
let codex_tool_call_params = CallToolRequestParams {
|
||||
name: "codex".to_string(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
};
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(codex_tool_call_params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_codex_reply_tool_call(
|
||||
&mut self,
|
||||
session_id: &str,
|
||||
prompt: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let codex_tool_call_params = CallToolRequestParams {
|
||||
name: "codex-reply".to_string(),
|
||||
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
|
||||
prompt: prompt.to_string(),
|
||||
session_id: session_id.to_string(),
|
||||
})?),
|
||||
};
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(codex_tool_call_params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<i64> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(request_id),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
});
|
||||
self.send_jsonrpc_message(message).await?;
|
||||
Ok(request_id)
|
||||
}
|
||||
|
||||
pub async fn send_response(
|
||||
&mut self,
|
||||
id: RequestId,
|
||||
result: serde_json::Value,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result,
|
||||
}))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_jsonrpc_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
|
||||
let payload = serde_json::to_string(&message)?;
|
||||
self.stdin.write_all(payload.as_bytes()).await?;
|
||||
self.stdin.write_all(b"\n").await?;
|
||||
self.stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_jsonrpc_message(&mut self) -> anyhow::Result<JSONRPCMessage> {
|
||||
let mut line = String::new();
|
||||
self.stdout.read_line(&mut line).await?;
|
||||
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
|
||||
Ok(message)
|
||||
}
|
||||
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(_) => {
|
||||
eprintln!("notification: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Request(jsonrpc_request) => {
|
||||
return Ok(jsonrpc_request);
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_response_message(
|
||||
&mut self,
|
||||
request_id: RequestId,
|
||||
) -> anyhow::Result<JSONRPCResponse> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(_) => {
|
||||
eprintln!("notification: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(jsonrpc_response) => {
|
||||
if jsonrpc_response.id == request_id {
|
||||
return Ok(jsonrpc_response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_configured_response_message(
|
||||
&mut self,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut sid_old: Option<String> = None;
|
||||
let mut sid_new: Option<String> = None;
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
if let Some(params) = notification.params {
|
||||
// Back-compat schema: method == "codex/event" and msg.type == "session_configured"
|
||||
if notification.method == "codex/event" {
|
||||
if let Some(msg) = params.get("msg") {
|
||||
if msg.get("type").and_then(|v| v.as_str())
|
||||
== Some("session_configured")
|
||||
{
|
||||
if let Some(session_id) =
|
||||
msg.get("session_id").and_then(|v| v.as_str())
|
||||
{
|
||||
sid_old = Some(session_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// New schema: method is the Display of EventMsg::SessionConfigured => "SessionConfigured"
|
||||
if notification.method == "session_configured" {
|
||||
if let Some(msg) = params.get("msg") {
|
||||
if let Some(session_id) =
|
||||
msg.get("session_id").and_then(|v| v.as_str())
|
||||
{
|
||||
sid_new = Some(session_id.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sid_old.is_some() && sid_new.is_some() {
|
||||
// Both seen, they must match
|
||||
assert_eq!(
|
||||
sid_old.as_ref().unwrap(),
|
||||
sid_new.as_ref().unwrap(),
|
||||
"session_id mismatch between old and new schema"
|
||||
);
|
||||
return Ok(sid_old.unwrap());
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_notification(
|
||||
&mut self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
}))
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Create a mock server that will provide the responses, in order, for
|
||||
/// requests to the `/v1/chat/completions` endpoint.
|
||||
pub async fn create_mock_chat_completions_server(responses: Vec<String>) -> MockServer {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let num_calls = responses.len();
|
||||
let seq_responder = SeqResponder {
|
||||
num_calls: AtomicUsize::new(0),
|
||||
responses,
|
||||
};
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/chat/completions"))
|
||||
.respond_with(seq_responder)
|
||||
.expect(num_calls as u64)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
struct SeqResponder {
|
||||
num_calls: AtomicUsize,
|
||||
responses: Vec<String>,
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
|
||||
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
|
||||
match self.responses.get(call_num) {
|
||||
Some(response) => ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(response.clone(), "text/event-stream"),
|
||||
None => panic!("no response for {call_num}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
use serde_json::json;
|
||||
use std::path::Path;
|
||||
|
||||
pub fn create_shell_sse_response(
|
||||
command: Vec<String>,
|
||||
workdir: Option<&Path>,
|
||||
timeout_ms: Option<u64>,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// The `arguments`` for the `shell` tool is a serialized JSON object.
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"command": command,
|
||||
"workdir": workdir.map(|w| w.to_string_lossy()),
|
||||
"timeout": timeout_ms
|
||||
}))?;
|
||||
let tool_call = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": tool_call_arguments
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&tool_call)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
|
||||
let assistant_message = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": message
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&assistant_message)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
pub fn create_apply_patch_sse_response(
|
||||
patch_content: &str,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// Use shell command to call apply_patch with heredoc format
|
||||
let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF");
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"command": ["bash", "-lc", shell_command]
|
||||
}))?;
|
||||
|
||||
let tool_call = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": tool_call_arguments
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&tool_call)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
#![cfg(unix)]
|
||||
// Support code lives in the `mcp_test_support` crate under tests/common.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_test_support::create_shell_sse_response;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shell_command_interruption() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = shell_command_interruption().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||
// Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist
|
||||
// (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately,
|
||||
// which triggers a second model request before the test sends the explicit follow-up. That
|
||||
// prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2).
|
||||
// Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`.
|
||||
#[cfg(target_os = "windows")]
|
||||
let shell_command = vec![
|
||||
"powershell".to_string(),
|
||||
"-Command".to_string(),
|
||||
"Start-Sleep -Seconds 60".to_string(),
|
||||
];
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
let shell_command = vec!["sleep".to_string(), "60".to_string()];
|
||||
let workdir_for_shell_function_call = TempDir::new()?;
|
||||
|
||||
// Create mock server with a single SSE response: the long sleep command
|
||||
let server = create_mock_chat_completions_server(vec![
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
Some(60_000), // 60 seconds timeout in ms
|
||||
"call_sleep",
|
||||
)?,
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
Some(60_000), // 60 seconds timeout in ms
|
||||
"call_sleep",
|
||||
)?,
|
||||
])
|
||||
.await;
|
||||
|
||||
// Create Codex configuration
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
|
||||
// Send codex tool call that triggers "sleep 60"
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
cwd: None,
|
||||
prompt: "First Run: run `sleep 60`".to_string(),
|
||||
model: None,
|
||||
profile: None,
|
||||
approval_policy: None,
|
||||
sandbox: None,
|
||||
config: None,
|
||||
base_instructions: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let session_id = mcp_process
|
||||
.read_stream_until_configured_response_message()
|
||||
.await?;
|
||||
|
||||
// Give the command a moment to start
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
// Send interrupt notification
|
||||
mcp_process
|
||||
.send_notification(
|
||||
"notifications/cancelled",
|
||||
Some(json!({ "requestId": codex_request_id })),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Expect Codex to return an error or interruption response
|
||||
let codex_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
|
||||
assert!(
|
||||
codex_response
|
||||
.result
|
||||
.as_object()
|
||||
.map(|o| o.contains_key("error"))
|
||||
.unwrap_or(false),
|
||||
"Expected an interruption or error result, got: {codex_response:?}"
|
||||
);
|
||||
|
||||
let codex_reply_request_id = mcp_process
|
||||
.send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`")
|
||||
.await?;
|
||||
|
||||
// Give the command a moment to start
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
// Send interrupt notification
|
||||
mcp_process
|
||||
.send_notification(
|
||||
"notifications/cancelled",
|
||||
Some(json!({ "requestId": codex_reply_request_id })),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Expect Codex to return an error or interruption response
|
||||
let codex_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)),
|
||||
)
|
||||
.await??;
|
||||
|
||||
assert!(
|
||||
codex_response
|
||||
.result
|
||||
.as_object()
|
||||
.map(|o| o.contains_key("error"))
|
||||
.unwrap_or(false),
|
||||
"Expected an interruption or error result, got: {codex_response:?}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Types for Model Context Protocol. Inspired by https://crates.io/crates/lsp-types.
|
||||
|
||||
As documented on https://modelcontextprotocol.io/specification/2025-06-18/basic:
|
||||
As documented on https://modelcontextprotocol.io/specification/2025-03-26/basic:
|
||||
|
||||
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.ts
|
||||
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.json
|
||||
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts
|
||||
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa: E501
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -14,13 +13,10 @@ from pathlib import Path
|
||||
# Helper first so it is defined when other functions call it.
|
||||
from typing import Any, Literal
|
||||
|
||||
SCHEMA_VERSION = "2025-06-18"
|
||||
SCHEMA_VERSION = "2025-03-26"
|
||||
JSONRPC_VERSION = "2.0"
|
||||
|
||||
STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n"
|
||||
STANDARD_HASHABLE_DERIVE = (
|
||||
"#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]\n"
|
||||
)
|
||||
|
||||
# Will be populated with the schema's `definitions` map in `main()` so that
|
||||
# helper functions (for example `define_any_of`) can perform look-ups while
|
||||
@@ -30,27 +26,19 @@ DEFINITIONS: dict[str, Any] = {}
|
||||
CLIENT_REQUEST_TYPE_NAMES: list[str] = []
|
||||
# Concrete *Notification types that make up the ServerNotification enum.
|
||||
SERVER_NOTIFICATION_TYPE_NAMES: list[str] = []
|
||||
# Enum types that will need a `allow(clippy::large_enum_variant)` annotation in
|
||||
# order to compile without warnings.
|
||||
LARGE_ENUMS = {"ServerResult"}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Embed, cluster and analyse text prompts via the OpenAI API.",
|
||||
)
|
||||
|
||||
default_schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"schema_file",
|
||||
nargs="?",
|
||||
default=default_schema_file,
|
||||
help="schema.json file to process",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
schema_file = args.schema_file
|
||||
num_args = len(sys.argv)
|
||||
if num_args == 1:
|
||||
schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
elif num_args == 2:
|
||||
schema_file = Path(sys.argv[1])
|
||||
else:
|
||||
print("Usage: python3 codegen.py <schema.json>")
|
||||
return 1
|
||||
|
||||
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
||||
|
||||
@@ -209,8 +197,6 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
|
||||
if name.endswith("Result"):
|
||||
out.extend(f"impl From<{name}> for serde_json::Value {{\n")
|
||||
out.append(f" fn from(value: {name}) -> Self {{\n")
|
||||
out.append(" // Leave this as it should never fail\n")
|
||||
out.append(" #[expect(clippy::unwrap_used)]\n")
|
||||
out.append(" serde_json::to_value(value).unwrap()\n")
|
||||
out.append(" }\n")
|
||||
out.append("}\n\n")
|
||||
@@ -225,7 +211,20 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
|
||||
any_of = definition.get("anyOf", [])
|
||||
if any_of:
|
||||
assert isinstance(any_of, list)
|
||||
out.extend(define_any_of(name, any_of, description))
|
||||
if name == "JSONRPCMessage":
|
||||
# Special case for JSONRPCMessage because its definition in the
|
||||
# JSON schema does not quite match how we think about this type
|
||||
# definition in Rust.
|
||||
deep_copied_any_of = json.loads(json.dumps(any_of))
|
||||
deep_copied_any_of[2] = {
|
||||
"$ref": "#/definitions/JSONRPCBatchRequest",
|
||||
}
|
||||
deep_copied_any_of[5] = {
|
||||
"$ref": "#/definitions/JSONRPCBatchResponse",
|
||||
}
|
||||
out.extend(define_any_of(name, deep_copied_any_of, description))
|
||||
else:
|
||||
out.extend(define_any_of(name, any_of, description))
|
||||
return
|
||||
|
||||
type_prop = definition.get("type", None)
|
||||
@@ -394,7 +393,7 @@ def define_string_enum(
|
||||
|
||||
|
||||
def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None:
|
||||
out.append(STANDARD_HASHABLE_DERIVE)
|
||||
out.append(STANDARD_DERIVE)
|
||||
out.append("#[serde(untagged)]\n")
|
||||
out.append(f"pub enum {name} {{\n")
|
||||
for simple_type in type_list:
|
||||
@@ -440,8 +439,6 @@ def define_any_of(
|
||||
if serde := get_serde_annotation_for_anyof_type(name):
|
||||
out.append(serde + "\n")
|
||||
|
||||
if name in LARGE_ENUMS:
|
||||
out.append("#[allow(clippy::large_enum_variant)]\n")
|
||||
out.append(f"pub enum {name} {{\n")
|
||||
|
||||
if name == "ClientRequest":
|
||||
@@ -599,8 +596,6 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
||||
prop_name = "r#type"
|
||||
elif name == "ref":
|
||||
prop_name = "r#ref"
|
||||
elif name == "enum":
|
||||
prop_name = "r#enum"
|
||||
elif snake_case := to_snake_case(name):
|
||||
prop_name = snake_case
|
||||
is_rename = True
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub const MCP_SCHEMA_VERSION: &str = "2025-06-18";
|
||||
pub const MCP_SCHEMA_VERSION: &str = "2025-03-26";
|
||||
pub const JSONRPC_VERSION: &str = "2.0";
|
||||
|
||||
/// Paired request/response types for the Model Context Protocol (MCP).
|
||||
@@ -35,12 +35,6 @@ fn default_jsonrpc() -> String {
|
||||
pub struct Annotations {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub audience: Option<Vec<Role>>,
|
||||
#[serde(
|
||||
rename = "lastModified",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub last_modified: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<f64>,
|
||||
}
|
||||
@@ -56,14 +50,6 @@ pub struct AudioContent {
|
||||
pub r#type: String, // &'static str = "audio"
|
||||
}
|
||||
|
||||
/// Base interface for metadata with name (identifier) and title (display name) properties.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BaseMetadata {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BlobResourceContents {
|
||||
pub blob: String,
|
||||
@@ -72,17 +58,6 @@ pub struct BlobResourceContents {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BooleanSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "boolean"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum CallToolRequest {}
|
||||
|
||||
@@ -100,17 +75,29 @@ pub struct CallToolRequestParams {
|
||||
}
|
||||
|
||||
/// The server's response to a tool call.
|
||||
///
|
||||
/// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
/// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
/// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
/// and self-correct.
|
||||
///
|
||||
/// However, any errors in _finding_ the tool, an error indicating that the
|
||||
/// server does not support tool calls, or any other exceptional conditions,
|
||||
/// should be reported as an MCP error response.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CallToolResult {
|
||||
pub content: Vec<ContentBlock>,
|
||||
pub content: Vec<CallToolResultContent>,
|
||||
#[serde(rename = "isError", default, skip_serializing_if = "Option::is_none")]
|
||||
pub is_error: Option<bool>,
|
||||
#[serde(
|
||||
rename = "structuredContent",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub structured_content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CallToolResultContent {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
impl From<CallToolResult> for serde_json::Value {
|
||||
@@ -140,8 +127,6 @@ pub struct CancelledNotificationParams {
|
||||
/// Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ClientCapabilities {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub elicitation: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub experimental: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
@@ -209,7 +194,6 @@ pub enum ClientResult {
|
||||
Result(Result),
|
||||
CreateMessageResult(CreateMessageResult),
|
||||
ListRootsResult(ListRootsResult),
|
||||
ElicitResult(ElicitResult),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -224,18 +208,9 @@ impl ModelContextProtocolRequest for CompleteRequest {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParams {
|
||||
pub argument: CompleteRequestParamsArgument,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<CompleteRequestParamsContext>,
|
||||
pub r#ref: CompleteRequestParamsRef,
|
||||
}
|
||||
|
||||
/// Additional, optional context for completions
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParamsContext {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// The argument's information
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParamsArgument {
|
||||
@@ -247,7 +222,7 @@ pub struct CompleteRequestParamsArgument {
|
||||
#[serde(untagged)]
|
||||
pub enum CompleteRequestParamsRef {
|
||||
PromptReference(PromptReference),
|
||||
ResourceTemplateReference(ResourceTemplateReference),
|
||||
ResourceReference(ResourceReference),
|
||||
}
|
||||
|
||||
/// The server's response to a completion/complete request
|
||||
@@ -273,16 +248,6 @@ impl From<CompleteResult> for serde_json::Value {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentBlock {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
ResourceLink(ResourceLink),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum CreateMessageRequest {}
|
||||
|
||||
@@ -360,48 +325,6 @@ impl From<CreateMessageResult> for serde_json::Value {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct Cursor(String);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ElicitRequest {}
|
||||
|
||||
impl ModelContextProtocolRequest for ElicitRequest {
|
||||
const METHOD: &'static str = "elicitation/create";
|
||||
type Params = ElicitRequestParams;
|
||||
type Result = ElicitResult;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitRequestParams {
|
||||
pub message: String,
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
}
|
||||
|
||||
/// A restricted subset of JSON Schema.
|
||||
/// Only top-level properties are allowed, without nesting.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitRequestParamsRequestedSchema {
|
||||
pub properties: serde_json::Value,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<Vec<String>>,
|
||||
pub r#type: String, // &'static str = "object"
|
||||
}
|
||||
|
||||
/// The client's response to an elicitation request.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitResult {
|
||||
pub action: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl From<ElicitResult> for serde_json::Value {
|
||||
fn from(value: ElicitResult) -> Self {
|
||||
// Leave this as it should never fail
|
||||
#[expect(clippy::unwrap_used)]
|
||||
serde_json::to_value(value).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// The contents of a resource, embedded into a prompt or tool call result.
|
||||
///
|
||||
/// It is up to the client how best to render embedded resources for the benefit
|
||||
@@ -423,18 +346,6 @@ pub enum EmbeddedResourceResource {
|
||||
|
||||
pub type EmptyResult = Result;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct EnumSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub r#enum: Vec<String>,
|
||||
#[serde(rename = "enumNames", default, skip_serializing_if = "Option::is_none")]
|
||||
pub enum_names: Option<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "string"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum GetPromptRequest {}
|
||||
|
||||
@@ -478,12 +389,10 @@ pub struct ImageContent {
|
||||
pub r#type: String, // &'static str = "image"
|
||||
}
|
||||
|
||||
/// Describes the name and version of an MCP implementation, with an optional title for UI representation.
|
||||
/// Describes the name and version of an MCP implementation.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct Implementation {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
@@ -533,6 +442,24 @@ impl ModelContextProtocolNotification for InitializedNotification {
|
||||
type Params = Option<serde_json::Value>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JSONRPCBatchRequestItem {
|
||||
JSONRPCRequest(JSONRPCRequest),
|
||||
JSONRPCNotification(JSONRPCNotification),
|
||||
}
|
||||
|
||||
pub type JSONRPCBatchRequest = Vec<JSONRPCBatchRequestItem>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JSONRPCBatchResponseItem {
|
||||
JSONRPCResponse(JSONRPCResponse),
|
||||
JSONRPCError(JSONRPCError),
|
||||
}
|
||||
|
||||
pub type JSONRPCBatchResponse = Vec<JSONRPCBatchResponseItem>;
|
||||
|
||||
/// A response to a request that indicates an error occurred.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct JSONRPCError {
|
||||
@@ -556,8 +483,10 @@ pub struct JSONRPCErrorError {
|
||||
pub enum JSONRPCMessage {
|
||||
Request(JSONRPCRequest),
|
||||
Notification(JSONRPCNotification),
|
||||
BatchRequest(JSONRPCBatchRequest),
|
||||
Response(JSONRPCResponse),
|
||||
Error(JSONRPCError),
|
||||
BatchResponse(JSONRPCBatchResponse),
|
||||
}
|
||||
|
||||
/// A notification which does not expect a response.
|
||||
@@ -848,19 +777,6 @@ pub struct Notification {
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct NumberSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub maximum: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub minimum: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PaginatedRequest {
|
||||
pub method: String,
|
||||
@@ -901,17 +817,6 @@ impl ModelContextProtocolRequest for PingRequest {
|
||||
type Result = Result;
|
||||
}
|
||||
|
||||
/// Restricted schema definitions that only allow primitive types
|
||||
/// without nested objects or arrays.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PrimitiveSchemaDefinition {
|
||||
StringSchema(StringSchema),
|
||||
NumberSchema(NumberSchema),
|
||||
BooleanSchema(BooleanSchema),
|
||||
EnumSchema(EnumSchema),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ProgressNotification {}
|
||||
|
||||
@@ -931,7 +836,7 @@ pub struct ProgressNotificationParams {
|
||||
pub total: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProgressToken {
|
||||
String(String),
|
||||
@@ -946,8 +851,6 @@ pub struct Prompt {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// Describes an argument that a prompt can accept.
|
||||
@@ -958,8 +861,6 @@ pub struct PromptArgument {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -976,16 +877,23 @@ impl ModelContextProtocolNotification for PromptListChangedNotification {
|
||||
/// resources from the MCP server.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PromptMessage {
|
||||
pub content: ContentBlock,
|
||||
pub content: PromptMessageContent,
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PromptMessageContent {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
/// Identifies a prompt.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PromptReference {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "ref/prompt"
|
||||
}
|
||||
|
||||
@@ -1031,7 +939,7 @@ pub struct Request {
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RequestId {
|
||||
String(String),
|
||||
@@ -1050,8 +958,6 @@ pub struct Resource {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
@@ -1063,26 +969,6 @@ pub struct ResourceContents {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// A resource that the server is capable of reading, included in a prompt or tool call result.
|
||||
///
|
||||
/// Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceLink {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub annotations: Option<Annotations>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "resource_link"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ResourceListChangedNotification {}
|
||||
|
||||
@@ -1091,6 +977,13 @@ impl ModelContextProtocolNotification for ResourceListChangedNotification {
|
||||
type Params = Option<serde_json::Value>;
|
||||
}
|
||||
|
||||
/// A reference to a resource or resource template definition.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: String, // &'static str = "ref/resource"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// A template description for resources available on the server.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceTemplate {
|
||||
@@ -1101,19 +994,10 @@ pub struct ResourceTemplate {
|
||||
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
#[serde(rename = "uriTemplate")]
|
||||
pub uri_template: String,
|
||||
}
|
||||
|
||||
/// A reference to a resource or resource template definition.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceTemplateReference {
|
||||
pub r#type: String, // &'static str = "ref/resource"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ResourceUpdatedNotification {}
|
||||
|
||||
@@ -1256,7 +1140,6 @@ pub enum ServerRequest {
|
||||
PingRequest(PingRequest),
|
||||
CreateMessageRequest(CreateMessageRequest),
|
||||
ListRootsRequest(ListRootsRequest),
|
||||
ElicitRequest(ElicitRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -1289,21 +1172,6 @@ pub struct SetLevelRequestParams {
|
||||
pub level: LoggingLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct StringSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
#[serde(rename = "maxLength", default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_length: Option<i64>,
|
||||
#[serde(rename = "minLength", default, skip_serializing_if = "Option::is_none")]
|
||||
pub min_length: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "string"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum SubscribeRequest {}
|
||||
|
||||
@@ -1345,25 +1213,6 @@ pub struct Tool {
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: ToolInputSchema,
|
||||
pub name: String,
|
||||
#[serde(
|
||||
rename = "outputSchema",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub output_schema: Option<ToolOutputSchema>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// An optional JSON Schema object defining the structure of the tool's output returned in
|
||||
/// the structuredContent field of a CallToolResult.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ToolOutputSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub properties: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<Vec<String>>,
|
||||
pub r#type: String, // &'static str = "object"
|
||||
}
|
||||
|
||||
/// A JSON Schema object defining the expected parameters for the tool.
|
||||
|
||||
@@ -17,8 +17,8 @@ fn deserialize_initialize_request() {
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"capabilities": {},
|
||||
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-06-18"
|
||||
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-03-26"
|
||||
}
|
||||
}"#;
|
||||
|
||||
@@ -37,8 +37,8 @@ fn deserialize_initialize_request() {
|
||||
method: "initialize".into(),
|
||||
params: Some(json!({
|
||||
"capabilities": {},
|
||||
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-06-18"
|
||||
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-03-26"
|
||||
})),
|
||||
};
|
||||
|
||||
@@ -57,14 +57,12 @@ fn deserialize_initialize_request() {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "acme-client".into(),
|
||||
title: Some("Acme".to_string()),
|
||||
version: "1.2.3".into(),
|
||||
},
|
||||
protocol_version: "2025-06-18".into(),
|
||||
protocol_version: "2025-03-26".into(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ anyhow = "1"
|
||||
base64 = "0.22.1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-ansi-escape = { path = "../ansi-escape" }
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
codex-core = { path = "../core" }
|
||||
codex-common = { path = "../common", features = [
|
||||
"cli",
|
||||
@@ -27,6 +26,7 @@ codex-common = { path = "../common", features = [
|
||||
"sandbox_summary",
|
||||
] }
|
||||
codex-file-search = { path = "../file-search" }
|
||||
codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
codex-login = { path = "../login" }
|
||||
color-eyre = "0.6.3"
|
||||
crossterm = { version = "0.28.1", features = ["bracketed-paste"] }
|
||||
@@ -35,16 +35,15 @@ lazy_static = "1"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
path-clean = "1.0.1"
|
||||
ratatui = { version = "0.29.0", features = [
|
||||
"scrolling-regions",
|
||||
"unstable-rendered-line-info",
|
||||
"unstable-widget-ref",
|
||||
"unstable-rendered-line-info",
|
||||
] }
|
||||
ratatui-image = "8.0.0"
|
||||
regex-lite = "0.1"
|
||||
serde_json = { version = "1", features = ["preserve_order"] }
|
||||
shlex = "1.3.0"
|
||||
strum = "0.27.2"
|
||||
strum_macros = "0.27.2"
|
||||
strum = "0.27.1"
|
||||
strum_macros = "0.27.1"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -59,10 +58,8 @@ tui-input = "0.14.0"
|
||||
tui-markdown = "0.3.3"
|
||||
tui-textarea = "0.7.0"
|
||||
unicode-segmentation = "1.12.0"
|
||||
unicode-width = "0.1"
|
||||
uuid = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
insta = "1.43.1"
|
||||
pretty_assertions = "1"
|
||||
tempfile = "3.13.0"
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::get_git_diff::get_git_diff;
|
||||
use crate::git_warning_screen::GitWarningOutcome;
|
||||
use crate::git_warning_screen::GitWarningScreen;
|
||||
use crate::login_screen::LoginScreen;
|
||||
use crate::mouse_capture::MouseCapture;
|
||||
use crate::scroll_event_helper::ScrollEventHelper;
|
||||
use crate::slash_command::SlashCommand;
|
||||
use crate::tui;
|
||||
@@ -18,8 +19,7 @@ use crossterm::event::MouseEvent;
|
||||
use crossterm::event::MouseEventKind;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::thread;
|
||||
@@ -54,7 +54,7 @@ pub(crate) struct App<'a> {
|
||||
file_search: FileSearchManager,
|
||||
|
||||
/// True when a redraw has been scheduled but not yet executed.
|
||||
pending_redraw: Arc<AtomicBool>,
|
||||
pending_redraw: Arc<Mutex<bool>>,
|
||||
|
||||
/// Stored parameters needed to instantiate the ChatWidget later, e.g.,
|
||||
/// after dismissing the Git-repo warning.
|
||||
@@ -68,7 +68,6 @@ struct ChatWidgetArgs {
|
||||
config: Config,
|
||||
initial_prompt: Option<String>,
|
||||
initial_images: Vec<PathBuf>,
|
||||
prompt_label: Option<String>,
|
||||
}
|
||||
|
||||
impl App<'_> {
|
||||
@@ -78,11 +77,10 @@ impl App<'_> {
|
||||
show_login_screen: bool,
|
||||
show_git_warning: bool,
|
||||
initial_images: Vec<std::path::PathBuf>,
|
||||
prompt_label: Option<String>,
|
||||
) -> Self {
|
||||
let (app_event_tx, app_event_rx) = channel();
|
||||
let app_event_tx = AppEventSender::new(app_event_tx);
|
||||
let pending_redraw = Arc::new(AtomicBool::new(false));
|
||||
let pending_redraw = Arc::new(Mutex::new(false));
|
||||
let scroll_event_helper = ScrollEventHelper::new(app_event_tx.clone());
|
||||
|
||||
// Spawn a dedicated thread for reading the crossterm event loop and
|
||||
@@ -90,51 +88,32 @@ impl App<'_> {
|
||||
{
|
||||
let app_event_tx = app_event_tx.clone();
|
||||
std::thread::spawn(move || {
|
||||
loop {
|
||||
// This timeout is necessary to avoid holding the event lock
|
||||
// that crossterm::event::read() acquires. In particular,
|
||||
// reading the cursor position (crossterm::cursor::position())
|
||||
// needs to acquire the event lock, and so will fail if it
|
||||
// can't acquire it within 2 sec. Resizing the terminal
|
||||
// crashes the app if the cursor position can't be read.
|
||||
if let Ok(true) = crossterm::event::poll(Duration::from_millis(100)) {
|
||||
if let Ok(event) = crossterm::event::read() {
|
||||
match event {
|
||||
crossterm::event::Event::Key(key_event) => {
|
||||
app_event_tx.send(AppEvent::KeyEvent(key_event));
|
||||
}
|
||||
crossterm::event::Event::Resize(_, _) => {
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
crossterm::event::Event::Mouse(MouseEvent {
|
||||
kind: MouseEventKind::ScrollUp,
|
||||
..
|
||||
}) => {
|
||||
scroll_event_helper.scroll_up();
|
||||
}
|
||||
crossterm::event::Event::Mouse(MouseEvent {
|
||||
kind: MouseEventKind::ScrollDown,
|
||||
..
|
||||
}) => {
|
||||
scroll_event_helper.scroll_down();
|
||||
}
|
||||
crossterm::event::Event::Paste(pasted) => {
|
||||
// Many terminals convert newlines to \r when
|
||||
// pasting, e.g. [iTerm2][]. But [tui-textarea
|
||||
// expects \n][tui-textarea]. This seems like a bug
|
||||
// in tui-textarea IMO, but work around it for now.
|
||||
// [tui-textarea]: https://github.com/rhysd/tui-textarea/blob/4d18622eeac13b309e0ff6a55a46ac6706da68cf/src/textarea.rs#L782-L783
|
||||
// [iTerm2]: https://github.com/gnachman/iTerm2/blob/5d0c0d9f68523cbd0494dad5422998964a2ecd8d/sources/iTermPasteHelper.m#L206-L216
|
||||
let pasted = pasted.replace("\r", "\n");
|
||||
app_event_tx.send(AppEvent::Paste(pasted));
|
||||
}
|
||||
_ => {
|
||||
// Ignore any other events.
|
||||
}
|
||||
}
|
||||
while let Ok(event) = crossterm::event::read() {
|
||||
match event {
|
||||
crossterm::event::Event::Key(key_event) => {
|
||||
app_event_tx.send(AppEvent::KeyEvent(key_event));
|
||||
}
|
||||
crossterm::event::Event::Resize(_, _) => {
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
crossterm::event::Event::Mouse(MouseEvent {
|
||||
kind: MouseEventKind::ScrollUp,
|
||||
..
|
||||
}) => {
|
||||
scroll_event_helper.scroll_up();
|
||||
}
|
||||
crossterm::event::Event::Mouse(MouseEvent {
|
||||
kind: MouseEventKind::ScrollDown,
|
||||
..
|
||||
}) => {
|
||||
scroll_event_helper.scroll_down();
|
||||
}
|
||||
crossterm::event::Event::Paste(pasted) => {
|
||||
app_event_tx.send(AppEvent::Paste(pasted));
|
||||
}
|
||||
_ => {
|
||||
// Ignore any other events.
|
||||
}
|
||||
} else {
|
||||
// Timeout expired, no `Event` is available
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -149,7 +128,6 @@ impl App<'_> {
|
||||
config: config.clone(),
|
||||
initial_prompt,
|
||||
initial_images,
|
||||
prompt_label: prompt_label.clone(),
|
||||
}),
|
||||
)
|
||||
} else if show_git_warning {
|
||||
@@ -161,7 +139,6 @@ impl App<'_> {
|
||||
config: config.clone(),
|
||||
initial_prompt,
|
||||
initial_images,
|
||||
prompt_label: prompt_label.clone(),
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
@@ -170,7 +147,6 @@ impl App<'_> {
|
||||
app_event_tx.clone(),
|
||||
initial_prompt,
|
||||
initial_images,
|
||||
prompt_label.clone(),
|
||||
);
|
||||
(
|
||||
AppState::Chat {
|
||||
@@ -201,14 +177,13 @@ impl App<'_> {
|
||||
/// Schedule a redraw if one is not already pending.
|
||||
#[allow(clippy::unwrap_used)]
|
||||
fn schedule_redraw(&self) {
|
||||
// Attempt to set the flag to `true`. If it was already `true`, another
|
||||
// redraw is already pending so we can return early.
|
||||
if self
|
||||
.pending_redraw
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
#[allow(clippy::unwrap_used)]
|
||||
let mut flag = self.pending_redraw.lock().unwrap();
|
||||
if *flag {
|
||||
return;
|
||||
}
|
||||
*flag = true;
|
||||
}
|
||||
|
||||
let tx = self.app_event_tx.clone();
|
||||
@@ -216,21 +191,23 @@ impl App<'_> {
|
||||
thread::spawn(move || {
|
||||
thread::sleep(REDRAW_DEBOUNCE);
|
||||
tx.send(AppEvent::Redraw);
|
||||
pending_redraw.store(false, Ordering::SeqCst);
|
||||
#[allow(clippy::unwrap_used)]
|
||||
let mut f = pending_redraw.lock().unwrap();
|
||||
*f = false;
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn run(&mut self, terminal: &mut tui::Tui) -> Result<()> {
|
||||
pub(crate) fn run(
|
||||
&mut self,
|
||||
terminal: &mut tui::Tui,
|
||||
mouse_capture: &mut MouseCapture,
|
||||
) -> Result<()> {
|
||||
// Insert an event to trigger the first render.
|
||||
let app_event_tx = self.app_event_tx.clone();
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
|
||||
while let Ok(event) = self.app_event_rx.recv() {
|
||||
match event {
|
||||
AppEvent::InsertHistory(lines) => {
|
||||
crate::insert_history::insert_history_lines(terminal, lines);
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
AppEvent::RequestRedraw => {
|
||||
self.schedule_redraw();
|
||||
}
|
||||
@@ -246,7 +223,9 @@ impl App<'_> {
|
||||
} => {
|
||||
match &mut self.app_state {
|
||||
AppState::Chat { widget } => {
|
||||
widget.on_ctrl_c();
|
||||
if widget.on_ctrl_c() {
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
}
|
||||
}
|
||||
AppState::Login { .. } | AppState::GitWarning { .. } => {
|
||||
// No-op.
|
||||
@@ -306,11 +285,15 @@ impl App<'_> {
|
||||
self.app_event_tx.clone(),
|
||||
None,
|
||||
Vec::new(),
|
||||
None,
|
||||
));
|
||||
self.app_state = AppState::Chat { widget: new_widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
}
|
||||
SlashCommand::ToggleMouseMode => {
|
||||
if let Err(e) = mouse_capture.toggle() {
|
||||
tracing::error!("Failed to toggle mouse mode: {e}");
|
||||
}
|
||||
}
|
||||
SlashCommand::Quit => {
|
||||
break;
|
||||
}
|
||||
@@ -351,15 +334,6 @@ impl App<'_> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn token_usage(&self) -> codex_core::protocol::TokenUsage {
|
||||
match &self.app_state {
|
||||
AppState::Chat { widget } => widget.token_usage().clone(),
|
||||
AppState::Login { .. } | AppState::GitWarning { .. } => {
|
||||
codex_core::protocol::TokenUsage::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn draw_next_frame(&mut self, terminal: &mut tui::Tui) -> Result<()> {
|
||||
// TODO: add a throttle to avoid redrawing too often
|
||||
|
||||
@@ -398,7 +372,6 @@ impl App<'_> {
|
||||
self.app_event_tx.clone(),
|
||||
args.initial_prompt,
|
||||
args.initial_images,
|
||||
args.prompt_label,
|
||||
));
|
||||
self.app_state = AppState::Chat { widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use codex_core::protocol::Event;
|
||||
use codex_file_search::FileMatch;
|
||||
use crossterm::event::KeyEvent;
|
||||
use ratatui::text::Line;
|
||||
|
||||
use crate::slash_command::SlashCommand;
|
||||
|
||||
@@ -50,6 +49,4 @@ pub(crate) enum AppEvent {
|
||||
query: String,
|
||||
matches: Vec<FileMatch>,
|
||||
},
|
||||
|
||||
InsertHistory(Vec<Line<'static>>),
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::user_approval_widget::UserApprovalWidget;
|
||||
|
||||
use super::BottomPane;
|
||||
use super::BottomPaneView;
|
||||
use super::CancellationEvent;
|
||||
|
||||
/// Modal overlay asking the user to approve/deny a sequence of requests.
|
||||
pub(crate) struct ApprovalModalView<'a> {
|
||||
@@ -47,16 +46,14 @@ impl<'a> BottomPaneView<'a> for ApprovalModalView<'a> {
|
||||
self.maybe_advance();
|
||||
}
|
||||
|
||||
fn on_ctrl_c(&mut self, _pane: &mut BottomPane<'a>) -> CancellationEvent {
|
||||
self.current.on_ctrl_c();
|
||||
self.queue.clear();
|
||||
CancellationEvent::Handled
|
||||
}
|
||||
|
||||
fn is_complete(&self) -> bool {
|
||||
self.current.is_complete() && self.queue.is_empty()
|
||||
}
|
||||
|
||||
fn calculate_required_height(&self, area: &Rect) -> u16 {
|
||||
self.current.get_height(area)
|
||||
}
|
||||
|
||||
fn render(&self, area: Rect, buf: &mut Buffer) {
|
||||
(&self.current).render_ref(area, buf);
|
||||
}
|
||||
@@ -66,39 +63,3 @@ impl<'a> BottomPaneView<'a> for ApprovalModalView<'a> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::app_event::AppEvent;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
fn make_exec_request() -> ApprovalRequest {
|
||||
ApprovalRequest::Exec {
|
||||
id: "test".to_string(),
|
||||
command: vec!["echo".to_string(), "hi".to_string()],
|
||||
cwd: PathBuf::from("/tmp"),
|
||||
reason: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ctrl_c_aborts_and_clears_queue() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let first = make_exec_request();
|
||||
let mut view = ApprovalModalView::new(first, tx);
|
||||
view.enqueue_request(make_exec_request());
|
||||
|
||||
let (tx_raw2, _rx2) = channel::<AppEvent>();
|
||||
let mut pane = BottomPane::new(super::super::BottomPaneParams {
|
||||
app_event_tx: AppEventSender::new(tx_raw2),
|
||||
has_input_focus: true,
|
||||
});
|
||||
assert_eq!(CancellationEvent::Handled, view.on_ctrl_c(&mut pane));
|
||||
assert!(view.queue.is_empty());
|
||||
assert!(view.current.is_complete());
|
||||
assert!(view.is_complete());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
|
||||
use super::BottomPane;
|
||||
use super::CancellationEvent;
|
||||
|
||||
/// Type to use for a method that may require a redraw of the UI.
|
||||
pub(crate) enum ConditionalUpdate {
|
||||
@@ -23,10 +22,8 @@ pub(crate) trait BottomPaneView<'a> {
|
||||
false
|
||||
}
|
||||
|
||||
/// Handle Ctrl-C while this view is active.
|
||||
fn on_ctrl_c(&mut self, _pane: &mut BottomPane<'a>) -> CancellationEvent {
|
||||
CancellationEvent::Ignored
|
||||
}
|
||||
/// Height required to render the view.
|
||||
fn calculate_required_height(&self, area: &Rect) -> u16;
|
||||
|
||||
/// Render the view: this will be displayed in place of the composer.
|
||||
fn render(&self, area: Rect, buf: &mut Buffer);
|
||||
|
||||
@@ -22,6 +22,11 @@ use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use codex_file_search::FileMatch;
|
||||
|
||||
/// Minimum number of visible text rows inside the textarea.
|
||||
const MIN_TEXTAREA_ROWS: usize = 1;
|
||||
/// Rows consumed by the border.
|
||||
const BORDER_LINES: u16 = 2;
|
||||
|
||||
const BASE_PLACEHOLDER_TEXT: &str = "send a message";
|
||||
/// If the pasted content exceeds this number of characters, replace it with a
|
||||
/// placeholder in the UI.
|
||||
@@ -127,6 +132,10 @@ impl ChatComposer<'_> {
|
||||
.on_entry_response(log_id, offset, entry, &mut self.textarea)
|
||||
}
|
||||
|
||||
pub fn set_input_focus(&mut self, has_focus: bool) {
|
||||
self.update_border(has_focus);
|
||||
}
|
||||
|
||||
pub fn handle_paste(&mut self, pasted: String) -> bool {
|
||||
let char_count = pasted.chars().count();
|
||||
if char_count > LARGE_PASTE_CHAR_THRESHOLD {
|
||||
@@ -600,6 +609,17 @@ impl ChatComposer<'_> {
|
||||
self.dismissed_file_popup_token = None;
|
||||
}
|
||||
|
||||
pub fn calculate_required_height(&self, area: &Rect) -> u16 {
|
||||
let rows = self.textarea.lines().len().max(MIN_TEXTAREA_ROWS);
|
||||
let num_popup_rows = match &self.active_popup {
|
||||
ActivePopup::Command(popup) => popup.calculate_required_height(area),
|
||||
ActivePopup::File(popup) => popup.calculate_required_height(area),
|
||||
ActivePopup::None => 0,
|
||||
};
|
||||
|
||||
rows as u16 + BORDER_LINES + num_popup_rows
|
||||
}
|
||||
|
||||
fn update_border(&mut self, has_focus: bool) {
|
||||
struct BlockState {
|
||||
right_title: Line<'static>,
|
||||
@@ -634,6 +654,13 @@ impl ChatComposer<'_> {
|
||||
.border_style(bs.border_style),
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn is_popup_visible(&self) -> bool {
|
||||
match self.active_popup {
|
||||
ActivePopup::Command(_) | ActivePopup::File(_) => true,
|
||||
ActivePopup::None => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &ChatComposer<'_> {
|
||||
|
||||
@@ -20,12 +20,6 @@ mod command_popup;
|
||||
mod file_search_popup;
|
||||
mod status_indicator_view;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum CancellationEvent {
|
||||
Ignored,
|
||||
Handled,
|
||||
}
|
||||
|
||||
pub(crate) use chat_composer::ChatComposer;
|
||||
pub(crate) use chat_composer::InputResult;
|
||||
|
||||
@@ -71,8 +65,10 @@ impl BottomPane<'_> {
|
||||
if !view.is_complete() {
|
||||
self.active_view = Some(view);
|
||||
} else if self.is_task_running {
|
||||
let height = self.composer.calculate_required_height(&Rect::default());
|
||||
self.active_view = Some(Box::new(StatusIndicatorView::new(
|
||||
self.app_event_tx.clone(),
|
||||
height,
|
||||
)));
|
||||
}
|
||||
self.request_redraw();
|
||||
@@ -86,33 +82,6 @@ impl BottomPane<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle Ctrl-C in the bottom pane. If a modal view is active it gets a
|
||||
/// chance to consume the event (e.g. to dismiss itself).
|
||||
pub(crate) fn on_ctrl_c(&mut self) -> CancellationEvent {
|
||||
let mut view = match self.active_view.take() {
|
||||
Some(view) => view,
|
||||
None => return CancellationEvent::Ignored,
|
||||
};
|
||||
|
||||
let event = view.on_ctrl_c(self);
|
||||
match event {
|
||||
CancellationEvent::Handled => {
|
||||
if !view.is_complete() {
|
||||
self.active_view = Some(view);
|
||||
} else if self.is_task_running {
|
||||
self.active_view = Some(Box::new(StatusIndicatorView::new(
|
||||
self.app_event_tx.clone(),
|
||||
)));
|
||||
}
|
||||
self.show_ctrl_c_quit_hint();
|
||||
}
|
||||
CancellationEvent::Ignored => {
|
||||
self.active_view = Some(view);
|
||||
}
|
||||
}
|
||||
event
|
||||
}
|
||||
|
||||
pub fn handle_paste(&mut self, pasted: String) {
|
||||
if self.active_view.is_none() {
|
||||
let needs_redraw = self.composer.handle_paste(pasted);
|
||||
@@ -137,6 +106,12 @@ impl BottomPane<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the UI to reflect whether this `BottomPane` has input focus.
|
||||
pub(crate) fn set_input_focus(&mut self, has_focus: bool) {
|
||||
self.has_input_focus = has_focus;
|
||||
self.composer.set_input_focus(has_focus);
|
||||
}
|
||||
|
||||
pub(crate) fn show_ctrl_c_quit_hint(&mut self) {
|
||||
self.ctrl_c_quit_hint = true;
|
||||
self.composer
|
||||
@@ -163,8 +138,10 @@ impl BottomPane<'_> {
|
||||
match (running, self.active_view.is_some()) {
|
||||
(true, false) => {
|
||||
// Show status indicator overlay.
|
||||
let height = self.composer.calculate_required_height(&Rect::default());
|
||||
self.active_view = Some(Box::new(StatusIndicatorView::new(
|
||||
self.app_event_tx.clone(),
|
||||
height,
|
||||
)));
|
||||
self.request_redraw();
|
||||
}
|
||||
@@ -226,10 +203,23 @@ impl BottomPane<'_> {
|
||||
}
|
||||
|
||||
/// Height (terminal rows) required by the current bottom pane.
|
||||
pub fn calculate_required_height(&self, area: &Rect) -> u16 {
|
||||
if let Some(view) = &self.active_view {
|
||||
view.calculate_required_height(area)
|
||||
} else {
|
||||
self.composer.calculate_required_height(area)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn request_redraw(&self) {
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw)
|
||||
}
|
||||
|
||||
/// Returns true when a popup inside the composer is visible.
|
||||
pub(crate) fn is_popup_visible(&self) -> bool {
|
||||
self.active_view.is_none() && self.composer.is_popup_visible()
|
||||
}
|
||||
|
||||
// --- History helpers ---
|
||||
|
||||
pub(crate) fn set_history_metadata(&mut self, log_id: u64, entry_count: usize) {
|
||||
@@ -267,34 +257,3 @@ impl WidgetRef for &BottomPane<'_> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::app_event::AppEvent;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
fn exec_request() -> ApprovalRequest {
|
||||
ApprovalRequest::Exec {
|
||||
id: "1".to_string(),
|
||||
command: vec!["echo".into(), "ok".into()],
|
||||
cwd: PathBuf::from("."),
|
||||
reason: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ctrl_c_on_modal_consumes_and_shows_quit_hint() {
|
||||
let (tx_raw, _rx) = channel::<AppEvent>();
|
||||
let tx = AppEventSender::new(tx_raw);
|
||||
let mut pane = BottomPane::new(BottomPaneParams {
|
||||
app_event_tx: tx,
|
||||
has_input_focus: true,
|
||||
});
|
||||
pane.push_approval_request(exec_request());
|
||||
assert_eq!(CancellationEvent::Handled, pane.on_ctrl_c());
|
||||
assert!(pane.ctrl_c_quit_hint_visible());
|
||||
assert_eq!(CancellationEvent::Ignored, pane.on_ctrl_c());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
use ratatui::widgets::WidgetRef;
|
||||
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
@@ -12,9 +13,9 @@ pub(crate) struct StatusIndicatorView {
|
||||
}
|
||||
|
||||
impl StatusIndicatorView {
|
||||
pub fn new(app_event_tx: AppEventSender) -> Self {
|
||||
pub fn new(app_event_tx: AppEventSender, height: u16) -> Self {
|
||||
Self {
|
||||
view: StatusIndicatorWidget::new(app_event_tx),
|
||||
view: StatusIndicatorWidget::new(app_event_tx, height),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +34,11 @@ impl BottomPaneView<'_> for StatusIndicatorView {
|
||||
true
|
||||
}
|
||||
|
||||
fn render(&self, area: ratatui::layout::Rect, buf: &mut Buffer) {
|
||||
fn calculate_required_height(&self, _area: &Rect) -> u16 {
|
||||
self.view.get_height()
|
||||
}
|
||||
|
||||
fn render(&self, area: Rect, buf: &mut Buffer) {
|
||||
self.view.render_ref(area, buf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::codex_wrapper::CodexConversation;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::AgentMessageDeltaEvent;
|
||||
@@ -24,6 +23,9 @@ use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use crossterm::event::KeyEvent;
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Constraint;
|
||||
use ratatui::layout::Direction;
|
||||
use ratatui::layout::Layout;
|
||||
use ratatui::layout::Rect;
|
||||
use ratatui::widgets::Widget;
|
||||
use ratatui::widgets::WidgetRef;
|
||||
@@ -34,10 +36,8 @@ use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use crate::bottom_pane::BottomPane;
|
||||
use crate::bottom_pane::BottomPaneParams;
|
||||
use crate::bottom_pane::CancellationEvent;
|
||||
use crate::bottom_pane::InputResult;
|
||||
use crate::conversation_history_widget::ConversationHistoryWidget;
|
||||
use crate::exec_command::strip_bash_lc_and_escape;
|
||||
use crate::history_cell::PatchEventType;
|
||||
use crate::user_approval_widget::ApprovalRequest;
|
||||
use codex_file_search::FileMatch;
|
||||
@@ -47,15 +47,19 @@ pub(crate) struct ChatWidget<'a> {
|
||||
codex_op_tx: UnboundedSender<Op>,
|
||||
conversation_history: ConversationHistoryWidget,
|
||||
bottom_pane: BottomPane<'a>,
|
||||
input_focus: InputFocus,
|
||||
config: Config,
|
||||
initial_user_message: Option<UserMessage>,
|
||||
token_usage: TokenUsage,
|
||||
reasoning_buffer: String,
|
||||
// Buffer for streaming assistant answer text; we do not surface partial
|
||||
// We wait for the final AgentMessage event and then emit the full text
|
||||
// at once into scrollback so the history contains a single message.
|
||||
answer_buffer: String,
|
||||
prompt_label: Option<String>,
|
||||
active_task_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Eq, PartialEq)]
|
||||
enum InputFocus {
|
||||
HistoryPane,
|
||||
BottomPane,
|
||||
}
|
||||
|
||||
struct UserMessage {
|
||||
@@ -86,7 +90,6 @@ impl ChatWidget<'_> {
|
||||
app_event_tx: AppEventSender,
|
||||
initial_prompt: Option<String>,
|
||||
initial_images: Vec<PathBuf>,
|
||||
prompt_label: Option<String>,
|
||||
) -> Self {
|
||||
let (codex_op_tx, mut codex_op_rx) = unbounded_channel::<Op>();
|
||||
|
||||
@@ -94,11 +97,7 @@ impl ChatWidget<'_> {
|
||||
// Create the Codex asynchronously so the UI loads as quickly as possible.
|
||||
let config_for_agent_loop = config.clone();
|
||||
tokio::spawn(async move {
|
||||
let CodexConversation {
|
||||
codex,
|
||||
session_configured,
|
||||
..
|
||||
} = match init_codex(config_for_agent_loop).await {
|
||||
let (codex, session_event, _ctrl_c) = match init_codex(config_for_agent_loop).await {
|
||||
Ok(vals) => vals,
|
||||
Err(e) => {
|
||||
// TODO: surface this error to the user.
|
||||
@@ -109,7 +108,7 @@ impl ChatWidget<'_> {
|
||||
|
||||
// Forward the captured `SessionInitialized` event that was consumed
|
||||
// inside `init_codex()` so it can be rendered in the UI.
|
||||
app_event_tx_clone.send(AppEvent::CodexEvent(session_configured.clone()));
|
||||
app_event_tx_clone.send(AppEvent::CodexEvent(session_event.clone()));
|
||||
let codex = Arc::new(codex);
|
||||
let codex_clone = codex.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -134,6 +133,7 @@ impl ChatWidget<'_> {
|
||||
app_event_tx,
|
||||
has_input_focus: true,
|
||||
}),
|
||||
input_focus: InputFocus::BottomPane,
|
||||
config,
|
||||
initial_user_message: create_initial_user_message(
|
||||
initial_prompt.unwrap_or_default(),
|
||||
@@ -142,29 +142,49 @@ impl ChatWidget<'_> {
|
||||
token_usage: TokenUsage::default(),
|
||||
reasoning_buffer: String::new(),
|
||||
answer_buffer: String::new(),
|
||||
prompt_label,
|
||||
active_task_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) {
|
||||
self.bottom_pane.clear_ctrl_c_quit_hint();
|
||||
// Special-case <Tab>: normally toggles focus between history and bottom panes.
|
||||
// However, when the slash-command popup is visible we forward the key
|
||||
// to the bottom pane so it can handle auto-completion.
|
||||
if matches!(key_event.code, crossterm::event::KeyCode::Tab)
|
||||
&& !self.bottom_pane.is_popup_visible()
|
||||
{
|
||||
self.input_focus = match self.input_focus {
|
||||
InputFocus::HistoryPane => InputFocus::BottomPane,
|
||||
InputFocus::BottomPane => InputFocus::HistoryPane,
|
||||
};
|
||||
self.conversation_history
|
||||
.set_input_focus(self.input_focus == InputFocus::HistoryPane);
|
||||
self.bottom_pane
|
||||
.set_input_focus(self.input_focus == InputFocus::BottomPane);
|
||||
self.request_redraw();
|
||||
return;
|
||||
}
|
||||
|
||||
match self.bottom_pane.handle_key_event(key_event) {
|
||||
InputResult::Submitted(text) => {
|
||||
self.submit_user_message(text.into());
|
||||
match self.input_focus {
|
||||
InputFocus::HistoryPane => {
|
||||
let needs_redraw = self.conversation_history.handle_key_event(key_event);
|
||||
if needs_redraw {
|
||||
self.request_redraw();
|
||||
}
|
||||
}
|
||||
InputResult::None => {}
|
||||
InputFocus::BottomPane => match self.bottom_pane.handle_key_event(key_event) {
|
||||
InputResult::Submitted(text) => {
|
||||
self.submit_user_message(text.into());
|
||||
}
|
||||
InputResult::None => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_paste(&mut self, text: String) {
|
||||
self.bottom_pane.handle_paste(text);
|
||||
}
|
||||
|
||||
/// Emits the last entry's plain lines from conversation_history, if any.
|
||||
fn emit_last_history_entry(&mut self) {
|
||||
if let Some(lines) = self.conversation_history.last_entry_plain_lines() {
|
||||
self.app_event_tx.send(AppEvent::InsertHistory(lines));
|
||||
if matches!(self.input_focus, InputFocus::BottomPane) {
|
||||
self.bottom_pane.handle_paste(text);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,26 +221,38 @@ impl ChatWidget<'_> {
|
||||
|
||||
// Only show text portion in conversation history for now.
|
||||
if !text.is_empty() {
|
||||
self.conversation_history.add_user_message(text.clone());
|
||||
self.emit_last_history_entry();
|
||||
self.conversation_history.add_user_message(text);
|
||||
}
|
||||
self.conversation_history.scroll_to_bottom();
|
||||
|
||||
// IMPORTANT: Starting a *new* user turn. Clear any partially streamed
|
||||
// answer from a previous turn (e.g., one that was interrupted) so that
|
||||
// the next AgentMessageDelta spawns a fresh agent message cell instead
|
||||
// of overwriting the last one.
|
||||
self.answer_buffer.clear();
|
||||
self.reasoning_buffer.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn handle_codex_event(&mut self, event: Event) {
|
||||
let Event { id, msg } = event;
|
||||
// Retain the event ID so we can refer to it after destructuring.
|
||||
let event_id = event.id.clone();
|
||||
let Event { id: _, msg } = event;
|
||||
|
||||
// When we are in the middle of a task (active_task_id is Some) we drop
|
||||
// streaming text/reasoning events for *other* task IDs. This prevents
|
||||
// late tokens from an interrupted run from bleeding into the current
|
||||
// answer.
|
||||
let should_drop_streaming = self
|
||||
.active_task_id
|
||||
.as_ref()
|
||||
.map(|active| active != &event_id)
|
||||
.unwrap_or(false);
|
||||
|
||||
match msg {
|
||||
EventMsg::SessionConfigured(event) => {
|
||||
// Record session information at the top of the conversation.
|
||||
self.conversation_history.add_session_info(
|
||||
&self.config,
|
||||
event.clone(),
|
||||
self.prompt_label.as_deref(),
|
||||
);
|
||||
// Immediately surface the session banner / settings summary in
|
||||
// scrollback so the user can review configuration (model,
|
||||
// sandbox, approvals, etc.) before interacting.
|
||||
self.emit_last_history_entry();
|
||||
self.conversation_history
|
||||
.add_session_info(&self.config, event.clone());
|
||||
|
||||
// Forward history metadata to the bottom pane so the chat
|
||||
// composer can navigate through past messages.
|
||||
@@ -236,53 +268,69 @@ impl ChatWidget<'_> {
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
// Final assistant answer. Prefer the fully provided message
|
||||
// from the event; if it is empty fall back to any accumulated
|
||||
// delta buffer (some providers may only stream deltas and send
|
||||
// an empty final message).
|
||||
let full = if message.is_empty() {
|
||||
std::mem::take(&mut self.answer_buffer)
|
||||
} else {
|
||||
self.answer_buffer.clear();
|
||||
message
|
||||
};
|
||||
if !full.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_message(&self.config, full);
|
||||
self.emit_last_history_entry();
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
// if the answer buffer is empty, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new answer.
|
||||
if self.answer_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_message(&self.config, message);
|
||||
} else {
|
||||
self.conversation_history
|
||||
.replace_prev_agent_message(&self.config, message);
|
||||
}
|
||||
self.answer_buffer.clear();
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
|
||||
// Buffer only – do not emit partial lines. This avoids cases
|
||||
// where long responses appear truncated if the terminal
|
||||
// wrapped early. The full message is emitted on
|
||||
// AgentMessage.
|
||||
self.answer_buffer.push_str(&delta);
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
if self.answer_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_message(&self.config, "".to_string());
|
||||
}
|
||||
self.answer_buffer.push_str(&delta.clone());
|
||||
self.conversation_history
|
||||
.replace_prev_agent_message(&self.config, self.answer_buffer.clone());
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
|
||||
// Buffer only – disable incremental reasoning streaming so we
|
||||
// avoid truncated intermediate lines. Full text emitted on
|
||||
// AgentReasoning.
|
||||
self.reasoning_buffer.push_str(&delta);
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
if self.reasoning_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, "".to_string());
|
||||
}
|
||||
self.reasoning_buffer.push_str(&delta.clone());
|
||||
self.conversation_history
|
||||
.replace_prev_agent_reasoning(&self.config, self.reasoning_buffer.clone());
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::AgentReasoning(AgentReasoningEvent { text }) => {
|
||||
// Emit full reasoning text once. Some providers might send
|
||||
// final event with empty text if only deltas were used.
|
||||
let full = if text.is_empty() {
|
||||
std::mem::take(&mut self.reasoning_buffer)
|
||||
} else {
|
||||
self.reasoning_buffer.clear();
|
||||
text
|
||||
};
|
||||
if !full.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, full);
|
||||
self.emit_last_history_entry();
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
// if the reasoning buffer is empty, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new reasoning.
|
||||
if self.reasoning_buffer.is_empty() {
|
||||
self.conversation_history
|
||||
.add_agent_reasoning(&self.config, "".to_string());
|
||||
} else {
|
||||
// else, we rerender one last time.
|
||||
self.conversation_history
|
||||
.replace_prev_agent_reasoning(&self.config, text);
|
||||
}
|
||||
self.reasoning_buffer.clear();
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::TaskStarted => {
|
||||
// New task has begun – update state and clear any stale buffers.
|
||||
self.active_task_id = Some(event_id);
|
||||
self.answer_buffer.clear();
|
||||
self.reasoning_buffer.clear();
|
||||
self.bottom_pane.clear_ctrl_c_quit_hint();
|
||||
self.bottom_pane.set_task_running(true);
|
||||
self.request_redraw();
|
||||
@@ -290,6 +338,10 @@ impl ChatWidget<'_> {
|
||||
EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: _,
|
||||
}) => {
|
||||
// Task finished; clear active_task_id so that subsequent events are processed.
|
||||
if self.active_task_id.as_ref() == Some(&event_id) {
|
||||
self.active_task_id = None;
|
||||
}
|
||||
self.bottom_pane.set_task_running(false);
|
||||
self.request_redraw();
|
||||
}
|
||||
@@ -299,45 +351,39 @@ impl ChatWidget<'_> {
|
||||
.set_token_usage(self.token_usage.clone(), self.config.model_context_window);
|
||||
}
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
self.conversation_history.add_error(message.clone());
|
||||
self.emit_last_history_entry();
|
||||
self.bottom_pane.set_task_running(false);
|
||||
// Error events always get surfaced (even for stale task IDs) so that the user sees
|
||||
// why a run stopped. However, only clear the running indicator if this is the
|
||||
// active task.
|
||||
if self.active_task_id.as_ref() == Some(&event_id) {
|
||||
self.bottom_pane.set_task_running(false);
|
||||
self.active_task_id = None;
|
||||
}
|
||||
self.conversation_history.add_error(message);
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
call_id: _,
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
}) => {
|
||||
// Print the command to the history so it is visible in the
|
||||
// transcript *before* the modal asks for approval.
|
||||
let cmdline = strip_bash_lc_and_escape(&command);
|
||||
let text = format!(
|
||||
"command requires approval:\n$ {cmdline}{reason}",
|
||||
reason = reason
|
||||
.as_ref()
|
||||
.map(|r| format!("\n{r}"))
|
||||
.unwrap_or_default()
|
||||
);
|
||||
self.conversation_history.add_background_event(text);
|
||||
self.emit_last_history_entry();
|
||||
self.conversation_history.scroll_to_bottom();
|
||||
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
let request = ApprovalRequest::Exec {
|
||||
id,
|
||||
id: event_id,
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
};
|
||||
self.bottom_pane.push_approval_request(request);
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id: _,
|
||||
changes,
|
||||
reason,
|
||||
grant_root,
|
||||
}) => {
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
// ------------------------------------------------------------------
|
||||
// Before we even prompt the user for approval we surface the patch
|
||||
// summary in the main conversation so that the dialog appears in a
|
||||
@@ -351,13 +397,12 @@ impl ChatWidget<'_> {
|
||||
|
||||
self.conversation_history
|
||||
.add_patch_event(PatchEventType::ApprovalRequest, changes);
|
||||
self.emit_last_history_entry();
|
||||
|
||||
self.conversation_history.scroll_to_bottom();
|
||||
|
||||
// Now surface the approval request in the BottomPane as before.
|
||||
let request = ApprovalRequest::ApplyPatch {
|
||||
id,
|
||||
id: event_id,
|
||||
reason,
|
||||
grant_root,
|
||||
};
|
||||
@@ -369,9 +414,11 @@ impl ChatWidget<'_> {
|
||||
command,
|
||||
cwd: _,
|
||||
}) => {
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
self.conversation_history
|
||||
.add_active_exec_command(call_id, command);
|
||||
self.emit_last_history_entry();
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
@@ -379,11 +426,13 @@ impl ChatWidget<'_> {
|
||||
auto_approved,
|
||||
changes,
|
||||
}) => {
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
// Even when a patch is auto‑approved we still display the
|
||||
// summary so the user can follow along.
|
||||
self.conversation_history
|
||||
.add_patch_event(PatchEventType::ApplyBegin { auto_approved }, changes);
|
||||
self.emit_last_history_entry();
|
||||
if !auto_approved {
|
||||
self.conversation_history.scroll_to_bottom();
|
||||
}
|
||||
@@ -395,6 +444,9 @@ impl ChatWidget<'_> {
|
||||
stdout,
|
||||
stderr,
|
||||
}) => {
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
self.conversation_history
|
||||
.record_completed_exec_command(call_id, stdout, stderr, exit_code);
|
||||
self.request_redraw();
|
||||
@@ -405,12 +457,17 @@ impl ChatWidget<'_> {
|
||||
tool,
|
||||
arguments,
|
||||
}) => {
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
self.conversation_history
|
||||
.add_active_mcp_tool_call(call_id, server, tool, arguments);
|
||||
self.emit_last_history_entry();
|
||||
self.request_redraw();
|
||||
}
|
||||
EventMsg::McpToolCallEnd(mcp_tool_call_end_event) => {
|
||||
if should_drop_streaming {
|
||||
return;
|
||||
}
|
||||
let success = mcp_tool_call_end_event.is_success();
|
||||
let McpToolCallEndEvent { call_id, result } = mcp_tool_call_end_event;
|
||||
self.conversation_history
|
||||
@@ -428,13 +485,9 @@ impl ChatWidget<'_> {
|
||||
self.bottom_pane
|
||||
.on_history_entry_response(log_id, offset, entry.map(|e| e.text));
|
||||
}
|
||||
EventMsg::ShutdownComplete => {
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
}
|
||||
event => {
|
||||
self.conversation_history
|
||||
.add_background_event(format!("{event:?}"));
|
||||
self.emit_last_history_entry();
|
||||
self.request_redraw();
|
||||
}
|
||||
}
|
||||
@@ -451,9 +504,7 @@ impl ChatWidget<'_> {
|
||||
}
|
||||
|
||||
pub(crate) fn add_diff_output(&mut self, diff_output: String) {
|
||||
self.conversation_history
|
||||
.add_diff_output(diff_output.clone());
|
||||
self.emit_last_history_entry();
|
||||
self.conversation_history.add_diff_output(diff_output);
|
||||
self.request_redraw();
|
||||
}
|
||||
|
||||
@@ -476,25 +527,18 @@ impl ChatWidget<'_> {
|
||||
}
|
||||
|
||||
/// Handle Ctrl-C key press.
|
||||
/// Returns CancellationEvent::Handled if the event was consumed by the UI, or
|
||||
/// CancellationEvent::Ignored if the caller should handle it (e.g. exit).
|
||||
pub(crate) fn on_ctrl_c(&mut self) -> CancellationEvent {
|
||||
match self.bottom_pane.on_ctrl_c() {
|
||||
CancellationEvent::Handled => return CancellationEvent::Handled,
|
||||
CancellationEvent::Ignored => {}
|
||||
}
|
||||
/// Returns true if the key press was handled, false if it was not.
|
||||
/// If the key press was not handled, the caller should handle it (likely by exiting the process).
|
||||
pub(crate) fn on_ctrl_c(&mut self) -> bool {
|
||||
if self.bottom_pane.is_task_running() {
|
||||
self.bottom_pane.clear_ctrl_c_quit_hint();
|
||||
self.submit_op(Op::Interrupt);
|
||||
self.answer_buffer.clear();
|
||||
self.reasoning_buffer.clear();
|
||||
CancellationEvent::Ignored
|
||||
false
|
||||
} else if self.bottom_pane.ctrl_c_quit_hint_visible() {
|
||||
self.submit_op(Op::Shutdown);
|
||||
CancellationEvent::Handled
|
||||
true
|
||||
} else {
|
||||
self.bottom_pane.show_ctrl_c_quit_hint();
|
||||
CancellationEvent::Ignored
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -508,18 +552,19 @@ impl ChatWidget<'_> {
|
||||
tracing::error!("failed to submit op: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn token_usage(&self) -> &TokenUsage {
|
||||
&self.token_usage
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for &ChatWidget<'_> {
|
||||
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
|
||||
// In the hybrid inline viewport mode we only draw the interactive
|
||||
// bottom pane; history entries are injected directly into scrollback
|
||||
// via `Terminal::insert_before`.
|
||||
(&self.bottom_pane).render(area, buf);
|
||||
let bottom_height = self.bottom_pane.calculate_required_height(&area);
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([Constraint::Min(0), Constraint::Length(bottom_height)])
|
||||
.split(area);
|
||||
|
||||
self.conversation_history.render(chunks[0], buf);
|
||||
(&self.bottom_pane).render(chunks[1], buf);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,38 +53,4 @@ pub struct Cli {
|
||||
|
||||
#[clap(skip)]
|
||||
pub config_overrides: CliConfigOverrides,
|
||||
|
||||
/// Override the built-in system prompt (base instructions).
|
||||
///
|
||||
/// If the value looks like a path to an existing file, the contents of the
|
||||
/// file are used. Otherwise, the value itself is used verbatim as the
|
||||
/// instructions string.
|
||||
#[arg(long = "experimental-instructions")]
|
||||
pub experimental_instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Cli;
|
||||
use clap::CommandFactory;
|
||||
|
||||
#[test]
|
||||
fn help_includes_file_behavior_for_experimental_instructions() {
|
||||
let mut cmd = Cli::command();
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
assert!(cmd.write_long_help(&mut buf).is_ok(), "help should render");
|
||||
let help = match String::from_utf8(buf) {
|
||||
Ok(s) => s,
|
||||
Err(e) => panic!("invalid utf8: {e}"),
|
||||
};
|
||||
assert!(help.contains("Override the built-in system prompt (base instructions)."));
|
||||
assert!(help.contains(
|
||||
"If the value looks like a path to an existing file, the contents of the file are used."
|
||||
));
|
||||
assert!(
|
||||
help.contains(
|
||||
"Otherwise, the value itself is used verbatim as the instructions string."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ use crate::history_cell::PatchEventType;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::style::Style;
|
||||
use ratatui::widgets::*;
|
||||
@@ -45,6 +47,33 @@ impl ConversationHistoryWidget {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_input_focus(&mut self, has_input_focus: bool) {
|
||||
self.has_input_focus = has_input_focus;
|
||||
}
|
||||
|
||||
/// Returns true if it needs a redraw.
|
||||
pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) -> bool {
|
||||
match key_event.code {
|
||||
KeyCode::Up | KeyCode::Char('k') => {
|
||||
self.scroll_up(1);
|
||||
true
|
||||
}
|
||||
KeyCode::Down | KeyCode::Char('j') => {
|
||||
self.scroll_down(1);
|
||||
true
|
||||
}
|
||||
KeyCode::PageUp | KeyCode::Char('b') => {
|
||||
self.scroll_page_up();
|
||||
true
|
||||
}
|
||||
KeyCode::PageDown | KeyCode::Char(' ') => {
|
||||
self.scroll_page_down();
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Negative delta scrolls up; positive delta scrolls down.
|
||||
pub(crate) fn scroll(&mut self, delta: i32) {
|
||||
match delta.cmp(&0) {
|
||||
@@ -93,18 +122,60 @@ impl ConversationHistoryWidget {
|
||||
}
|
||||
}
|
||||
|
||||
/// Scroll up by one full viewport height (Page Up).
|
||||
fn scroll_page_up(&mut self) {
|
||||
let viewport_height = self.last_viewport_height.get().max(1);
|
||||
|
||||
// If we are currently in the "stick to bottom" mode, first convert the
|
||||
// implicit scroll position (`usize::MAX`) into an explicit offset that
|
||||
// represents the very bottom of the scroll region. This mirrors the
|
||||
// logic from `scroll_up()`.
|
||||
if self.scroll_position == usize::MAX {
|
||||
self.scroll_position = self
|
||||
.num_rendered_lines
|
||||
.get()
|
||||
.saturating_sub(viewport_height);
|
||||
}
|
||||
|
||||
// Move up by a full page.
|
||||
self.scroll_position = self.scroll_position.saturating_sub(viewport_height);
|
||||
}
|
||||
|
||||
/// Scroll down by one full viewport height (Page Down).
|
||||
fn scroll_page_down(&mut self) {
|
||||
// Nothing to do if we're already stuck to the bottom.
|
||||
if self.scroll_position == usize::MAX {
|
||||
return;
|
||||
}
|
||||
|
||||
let viewport_height = self.last_viewport_height.get().max(1);
|
||||
let num_lines = self.num_rendered_lines.get();
|
||||
|
||||
// Calculate the maximum explicit scroll offset that is still within
|
||||
// range. This matches the logic in `scroll_down()` and the render
|
||||
// method.
|
||||
let max_scroll = num_lines.saturating_sub(viewport_height);
|
||||
|
||||
// Attempt to move down by a full page.
|
||||
let new_pos = self.scroll_position.saturating_add(viewport_height);
|
||||
|
||||
if new_pos >= max_scroll {
|
||||
// We have reached (or passed) the bottom – switch back to
|
||||
// automatic stick‑to‑bottom mode so that subsequent output keeps
|
||||
// the viewport pinned.
|
||||
self.scroll_position = usize::MAX;
|
||||
} else {
|
||||
self.scroll_position = new_pos;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scroll_to_bottom(&mut self) {
|
||||
self.scroll_position = usize::MAX;
|
||||
}
|
||||
|
||||
/// Note `model` could differ from `config.model` if the agent decided to
|
||||
/// use a different model than the one requested by the user.
|
||||
pub fn add_session_info(
|
||||
&mut self,
|
||||
config: &Config,
|
||||
event: SessionConfiguredEvent,
|
||||
prompt_label: Option<&str>,
|
||||
) {
|
||||
pub fn add_session_info(&mut self, config: &Config, event: SessionConfiguredEvent) {
|
||||
// In practice, SessionConfiguredEvent should always be the first entry
|
||||
// in the history, but it is possible that an error could be sent
|
||||
// before the session info.
|
||||
@@ -116,7 +187,6 @@ impl ConversationHistoryWidget {
|
||||
config,
|
||||
event,
|
||||
!has_welcome_message,
|
||||
prompt_label,
|
||||
));
|
||||
}
|
||||
|
||||
@@ -132,6 +202,14 @@ impl ConversationHistoryWidget {
|
||||
self.add_to_history(HistoryCell::new_agent_reasoning(config, text));
|
||||
}
|
||||
|
||||
pub fn replace_prev_agent_reasoning(&mut self, config: &Config, text: String) {
|
||||
self.replace_last_agent_reasoning(config, text);
|
||||
}
|
||||
|
||||
pub fn replace_prev_agent_message(&mut self, config: &Config, text: String) {
|
||||
self.replace_last_agent_message(config, text);
|
||||
}
|
||||
|
||||
pub fn add_background_event(&mut self, message: String) {
|
||||
self.add_to_history(HistoryCell::new_background_event(message));
|
||||
}
|
||||
@@ -179,10 +257,40 @@ impl ConversationHistoryWidget {
|
||||
});
|
||||
}
|
||||
|
||||
/// Return the lines for the most recently appended entry (if any) so the
|
||||
/// parent widget can surface them via the new scrollback insertion path.
|
||||
pub(crate) fn last_entry_plain_lines(&self) -> Option<Vec<Line<'static>>> {
|
||||
self.entries.last().map(|e| e.cell.plain_lines())
|
||||
pub fn replace_last_agent_reasoning(&mut self, config: &Config, text: String) {
|
||||
if let Some(idx) = self
|
||||
.entries
|
||||
.iter()
|
||||
.rposition(|entry| matches!(entry.cell, HistoryCell::AgentReasoning { .. }))
|
||||
{
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_reasoning(config, text);
|
||||
let height = if width > 0 {
|
||||
entry.cell.height(width)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
entry.line_count.set(height);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn replace_last_agent_message(&mut self, config: &Config, text: String) {
|
||||
if let Some(idx) = self
|
||||
.entries
|
||||
.iter()
|
||||
.rposition(|entry| matches!(entry.cell, HistoryCell::AgentMessage { .. }))
|
||||
{
|
||||
let width = self.cached_width.get();
|
||||
let entry = &mut self.entries[idx];
|
||||
entry.cell = HistoryCell::new_agent_message(config, text);
|
||||
let height = if width > 0 {
|
||||
entry.cell.height(width)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
entry.line_count.set(height);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_completed_exec_command(
|
||||
|
||||
@@ -17,7 +17,6 @@ use image::GenericImageView;
|
||||
use image::ImageReader;
|
||||
use lazy_static::lazy_static;
|
||||
use mcp_types::EmbeddedResourceResource;
|
||||
use mcp_types::ResourceLink;
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::style::Color;
|
||||
use ratatui::style::Modifier;
|
||||
@@ -123,35 +122,10 @@ pub(crate) enum HistoryCell {
|
||||
const TOOL_CALL_MAX_LINES: usize = 5;
|
||||
|
||||
impl HistoryCell {
|
||||
/// Return a cloned, plain representation of the cell's lines suitable for
|
||||
/// one‑shot insertion into the terminal scrollback. Image cells are
|
||||
/// represented with a simple placeholder for now.
|
||||
pub(crate) fn plain_lines(&self) -> Vec<Line<'static>> {
|
||||
match self {
|
||||
HistoryCell::WelcomeMessage { view }
|
||||
| HistoryCell::UserPrompt { view }
|
||||
| HistoryCell::AgentMessage { view }
|
||||
| HistoryCell::AgentReasoning { view }
|
||||
| HistoryCell::BackgroundEvent { view }
|
||||
| HistoryCell::GitDiffOutput { view }
|
||||
| HistoryCell::ErrorEvent { view }
|
||||
| HistoryCell::SessionInfo { view }
|
||||
| HistoryCell::CompletedExecCommand { view }
|
||||
| HistoryCell::CompletedMcpToolCall { view }
|
||||
| HistoryCell::PendingPatch { view }
|
||||
| HistoryCell::ActiveExecCommand { view, .. }
|
||||
| HistoryCell::ActiveMcpToolCall { view, .. } => view.lines.clone(),
|
||||
HistoryCell::CompletedMcpToolCallWithImageOutput { .. } => vec![
|
||||
Line::from("tool result (image output omitted)"),
|
||||
Line::from(""),
|
||||
],
|
||||
}
|
||||
}
|
||||
pub(crate) fn new_session_info(
|
||||
config: &Config,
|
||||
event: SessionConfiguredEvent,
|
||||
is_first_event: bool,
|
||||
prompt_label: Option<&str>,
|
||||
) -> Self {
|
||||
let SessionConfiguredEvent {
|
||||
model,
|
||||
@@ -181,12 +155,9 @@ impl HistoryCell {
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", config.approval_policy.to_string()),
|
||||
("approval", format!("{:?}", config.approval_policy)),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if let Some(label) = prompt_label {
|
||||
entries.push(("prompt", label.to_string()));
|
||||
}
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
@@ -360,7 +331,8 @@ impl HistoryCell {
|
||||
) -> Option<Self> {
|
||||
match result {
|
||||
Ok(mcp_types::CallToolResult { content, .. }) => {
|
||||
if let Some(mcp_types::ContentBlock::ImageContent(image)) = content.first() {
|
||||
if let Some(mcp_types::CallToolResultContent::ImageContent(image)) = content.first()
|
||||
{
|
||||
let raw_data =
|
||||
match base64::engine::general_purpose::STANDARD.decode(&image.data) {
|
||||
Ok(data) => data,
|
||||
@@ -433,21 +405,21 @@ impl HistoryCell {
|
||||
|
||||
for tool_call_result in content {
|
||||
let line_text = match tool_call_result {
|
||||
mcp_types::ContentBlock::TextContent(text) => {
|
||||
mcp_types::CallToolResultContent::TextContent(text) => {
|
||||
format_and_truncate_tool_result(
|
||||
&text.text,
|
||||
TOOL_CALL_MAX_LINES,
|
||||
num_cols as usize,
|
||||
)
|
||||
}
|
||||
mcp_types::ContentBlock::ImageContent(_) => {
|
||||
mcp_types::CallToolResultContent::ImageContent(_) => {
|
||||
// TODO show images even if they're not the first result, will require a refactor of `CompletedMcpToolCall`
|
||||
"<image content>".to_string()
|
||||
}
|
||||
mcp_types::ContentBlock::AudioContent(_) => {
|
||||
mcp_types::CallToolResultContent::AudioContent(_) => {
|
||||
"<audio content>".to_string()
|
||||
}
|
||||
mcp_types::ContentBlock::EmbeddedResource(resource) => {
|
||||
mcp_types::CallToolResultContent::EmbeddedResource(resource) => {
|
||||
let uri = match resource.resource {
|
||||
EmbeddedResourceResource::TextResourceContents(text) => {
|
||||
text.uri
|
||||
@@ -458,9 +430,6 @@ impl HistoryCell {
|
||||
};
|
||||
format!("embedded resource: {uri}")
|
||||
}
|
||||
mcp_types::ContentBlock::ResourceLink(ResourceLink { uri, .. }) => {
|
||||
format!("link: {uri}")
|
||||
}
|
||||
};
|
||||
lines.push(Line::styled(line_text, Style::default().fg(Color::Gray)));
|
||||
}
|
||||
@@ -585,88 +554,6 @@ impl HistoryCell {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use uuid::Uuid;
|
||||
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn minimal_config() -> Config {
|
||||
let cwd = match TempDir::new() {
|
||||
Ok(t) => t,
|
||||
Err(e) => panic!("tempdir error: {e}"),
|
||||
};
|
||||
let codex_home = match TempDir::new() {
|
||||
Ok(t) => t,
|
||||
Err(e) => panic!("tempdir error: {e}"),
|
||||
};
|
||||
let cfg = ConfigToml {
|
||||
..Default::default()
|
||||
};
|
||||
let overrides = ConfigOverrides {
|
||||
cwd: Some(cwd.path().to_path_buf()),
|
||||
..Default::default()
|
||||
};
|
||||
match Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
overrides,
|
||||
codex_home.path().to_path_buf(),
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => panic!("config error: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn lines_to_strings(lines: &[Line<'static>]) -> Vec<String> {
|
||||
lines
|
||||
.iter()
|
||||
.map(|line| line.spans.iter().map(|s| s.content.to_string()).collect())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn welcome_includes_prompt_label_experimental() {
|
||||
let cfg = minimal_config();
|
||||
let event = SessionConfiguredEvent {
|
||||
session_id: Uuid::nil(),
|
||||
model: cfg.model.clone(),
|
||||
history_log_id: 0,
|
||||
history_entry_count: 0,
|
||||
};
|
||||
let cell = HistoryCell::new_session_info(&cfg, event, true, Some("experimental"));
|
||||
let lines = cell.plain_lines();
|
||||
let strings = lines_to_strings(&lines);
|
||||
assert!(
|
||||
strings.iter().any(|s| s.contains("prompt: experimental")),
|
||||
"welcome should include prompt label; got: {strings:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn welcome_includes_prompt_label_filename() {
|
||||
let cfg = minimal_config();
|
||||
let event = SessionConfiguredEvent {
|
||||
session_id: Uuid::nil(),
|
||||
model: cfg.model.clone(),
|
||||
history_log_id: 0,
|
||||
history_entry_count: 0,
|
||||
};
|
||||
let cell = HistoryCell::new_session_info(&cfg, event, true, Some("instructions.md"));
|
||||
let lines = cell.plain_lines();
|
||||
let strings = lines_to_strings(&lines);
|
||||
assert!(
|
||||
strings
|
||||
.iter()
|
||||
.any(|s| s.contains("prompt: instructions.md")),
|
||||
"welcome should include filename prompt label; got: {strings:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// `CellWidget` implementation – most variants delegate to their internal
|
||||
// `TextBlock`. Variants that need custom painting can add their own logic in
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
|
||||
use crate::tui;
|
||||
use crossterm::Command;
|
||||
use crossterm::queue;
|
||||
use crossterm::style::Color as CColor;
|
||||
use crossterm::style::Colors;
|
||||
use crossterm::style::Print;
|
||||
use crossterm::style::SetAttribute;
|
||||
use crossterm::style::SetBackgroundColor;
|
||||
use crossterm::style::SetColors;
|
||||
use crossterm::style::SetForegroundColor;
|
||||
use ratatui::layout::Position;
|
||||
use ratatui::layout::Size;
|
||||
use ratatui::prelude::Backend;
|
||||
use ratatui::style::Color;
|
||||
use ratatui::style::Modifier;
|
||||
use ratatui::text::Line;
|
||||
use ratatui::text::Span;
|
||||
|
||||
/// Insert `lines` above the viewport.
|
||||
pub(crate) fn insert_history_lines(terminal: &mut tui::Tui, lines: Vec<Line<'static>>) {
|
||||
let screen_size = terminal.backend().size().unwrap_or(Size::new(0, 0));
|
||||
|
||||
let mut area = terminal.get_frame().area();
|
||||
|
||||
let wrapped_lines = wrapped_line_count(&lines, area.width);
|
||||
let cursor_top = if area.bottom() < screen_size.height {
|
||||
// If the viewport is not at the bottom of the screen, scroll it down to make room.
|
||||
// Don't scroll it past the bottom of the screen.
|
||||
let scroll_amount = wrapped_lines.min(screen_size.height - area.bottom());
|
||||
terminal
|
||||
.backend_mut()
|
||||
.scroll_region_down(area.top()..screen_size.height, scroll_amount)
|
||||
.ok();
|
||||
let cursor_top = area.top() - 1;
|
||||
area.y += scroll_amount;
|
||||
terminal.set_viewport_area(area);
|
||||
cursor_top
|
||||
} else {
|
||||
area.top() - 1
|
||||
};
|
||||
|
||||
// Limit the scroll region to the lines from the top of the screen to the
|
||||
// top of the viewport. With this in place, when we add lines inside this
|
||||
// area, only the lines in this area will be scrolled. We place the cursor
|
||||
// at the end of the scroll region, and add lines starting there.
|
||||
//
|
||||
// ┌─Screen───────────────────────┐
|
||||
// │┌╌Scroll region╌╌╌╌╌╌╌╌╌╌╌╌╌╌┐│
|
||||
// │┆ ┆│
|
||||
// │┆ ┆│
|
||||
// │┆ ┆│
|
||||
// │█╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┘│
|
||||
// │╭─Viewport───────────────────╮│
|
||||
// ││ ││
|
||||
// │╰────────────────────────────╯│
|
||||
// └──────────────────────────────┘
|
||||
queue!(std::io::stdout(), SetScrollRegion(1..area.top())).ok();
|
||||
|
||||
terminal
|
||||
.set_cursor_position(Position::new(0, cursor_top))
|
||||
.ok();
|
||||
|
||||
for line in lines {
|
||||
queue!(std::io::stdout(), Print("\r\n")).ok();
|
||||
write_spans(&mut std::io::stdout(), line.iter()).ok();
|
||||
}
|
||||
|
||||
queue!(std::io::stdout(), ResetScrollRegion).ok();
|
||||
}
|
||||
|
||||
fn wrapped_line_count(lines: &[Line], width: u16) -> u16 {
|
||||
let mut count = 0;
|
||||
for line in lines {
|
||||
count += line_height(line, width);
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
fn line_height(line: &Line, width: u16) -> u16 {
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
// get the total display width of the line, accounting for double-width chars
|
||||
let total_width = line
|
||||
.spans
|
||||
.iter()
|
||||
.map(|span| span.content.width())
|
||||
.sum::<usize>();
|
||||
// divide by width to get the number of lines, rounding up
|
||||
if width == 0 {
|
||||
1
|
||||
} else {
|
||||
(total_width as u16).div_ceil(width).max(1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SetScrollRegion(pub std::ops::Range<u16>);
|
||||
|
||||
impl Command for SetScrollRegion {
|
||||
fn write_ansi(&self, f: &mut impl fmt::Write) -> fmt::Result {
|
||||
write!(f, "\x1b[{};{}r", self.0.start, self.0.end)
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn execute_winapi(&self) -> std::io::Result<()> {
|
||||
panic!("tried to execute SetScrollRegion command using WinAPI, use ANSI instead");
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn is_ansi_code_supported(&self) -> bool {
|
||||
// TODO(nornagon): is this supported on Windows?
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct ResetScrollRegion;
|
||||
|
||||
impl Command for ResetScrollRegion {
|
||||
fn write_ansi(&self, f: &mut impl fmt::Write) -> fmt::Result {
|
||||
write!(f, "\x1b[r")
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn execute_winapi(&self) -> std::io::Result<()> {
|
||||
panic!("tried to execute ResetScrollRegion command using WinAPI, use ANSI instead");
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn is_ansi_code_supported(&self) -> bool {
|
||||
// TODO(nornagon): is this supported on Windows?
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
struct ModifierDiff {
|
||||
pub from: Modifier,
|
||||
pub to: Modifier,
|
||||
}
|
||||
|
||||
impl ModifierDiff {
|
||||
fn queue<W>(self, mut w: W) -> io::Result<()>
|
||||
where
|
||||
W: io::Write,
|
||||
{
|
||||
use crossterm::style::Attribute as CAttribute;
|
||||
let removed = self.from - self.to;
|
||||
if removed.contains(Modifier::REVERSED) {
|
||||
queue!(w, SetAttribute(CAttribute::NoReverse))?;
|
||||
}
|
||||
if removed.contains(Modifier::BOLD) {
|
||||
queue!(w, SetAttribute(CAttribute::NormalIntensity))?;
|
||||
if self.to.contains(Modifier::DIM) {
|
||||
queue!(w, SetAttribute(CAttribute::Dim))?;
|
||||
}
|
||||
}
|
||||
if removed.contains(Modifier::ITALIC) {
|
||||
queue!(w, SetAttribute(CAttribute::NoItalic))?;
|
||||
}
|
||||
if removed.contains(Modifier::UNDERLINED) {
|
||||
queue!(w, SetAttribute(CAttribute::NoUnderline))?;
|
||||
}
|
||||
if removed.contains(Modifier::DIM) {
|
||||
queue!(w, SetAttribute(CAttribute::NormalIntensity))?;
|
||||
}
|
||||
if removed.contains(Modifier::CROSSED_OUT) {
|
||||
queue!(w, SetAttribute(CAttribute::NotCrossedOut))?;
|
||||
}
|
||||
if removed.contains(Modifier::SLOW_BLINK) || removed.contains(Modifier::RAPID_BLINK) {
|
||||
queue!(w, SetAttribute(CAttribute::NoBlink))?;
|
||||
}
|
||||
|
||||
let added = self.to - self.from;
|
||||
if added.contains(Modifier::REVERSED) {
|
||||
queue!(w, SetAttribute(CAttribute::Reverse))?;
|
||||
}
|
||||
if added.contains(Modifier::BOLD) {
|
||||
queue!(w, SetAttribute(CAttribute::Bold))?;
|
||||
}
|
||||
if added.contains(Modifier::ITALIC) {
|
||||
queue!(w, SetAttribute(CAttribute::Italic))?;
|
||||
}
|
||||
if added.contains(Modifier::UNDERLINED) {
|
||||
queue!(w, SetAttribute(CAttribute::Underlined))?;
|
||||
}
|
||||
if added.contains(Modifier::DIM) {
|
||||
queue!(w, SetAttribute(CAttribute::Dim))?;
|
||||
}
|
||||
if added.contains(Modifier::CROSSED_OUT) {
|
||||
queue!(w, SetAttribute(CAttribute::CrossedOut))?;
|
||||
}
|
||||
if added.contains(Modifier::SLOW_BLINK) {
|
||||
queue!(w, SetAttribute(CAttribute::SlowBlink))?;
|
||||
}
|
||||
if added.contains(Modifier::RAPID_BLINK) {
|
||||
queue!(w, SetAttribute(CAttribute::RapidBlink))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn write_spans<'a, I>(mut writer: &mut impl Write, content: I) -> io::Result<()>
|
||||
where
|
||||
I: Iterator<Item = &'a Span<'a>>,
|
||||
{
|
||||
let mut fg = Color::Reset;
|
||||
let mut bg = Color::Reset;
|
||||
let mut modifier = Modifier::empty();
|
||||
for span in content {
|
||||
let mut next_modifier = modifier;
|
||||
next_modifier.insert(span.style.add_modifier);
|
||||
next_modifier.remove(span.style.sub_modifier);
|
||||
if next_modifier != modifier {
|
||||
let diff = ModifierDiff {
|
||||
from: modifier,
|
||||
to: next_modifier,
|
||||
};
|
||||
diff.queue(&mut writer)?;
|
||||
modifier = next_modifier;
|
||||
}
|
||||
let next_fg = span.style.fg.unwrap_or(Color::Reset);
|
||||
let next_bg = span.style.bg.unwrap_or(Color::Reset);
|
||||
if next_fg != fg || next_bg != bg {
|
||||
queue!(
|
||||
writer,
|
||||
SetColors(Colors::new(next_fg.into(), next_bg.into()))
|
||||
)?;
|
||||
fg = next_fg;
|
||||
bg = next_bg;
|
||||
}
|
||||
|
||||
queue!(writer, Print(span.content.clone()))?;
|
||||
}
|
||||
|
||||
queue!(
|
||||
writer,
|
||||
SetForegroundColor(CColor::Reset),
|
||||
SetBackgroundColor(CColor::Reset),
|
||||
SetAttribute(crossterm::style::Attribute::Reset),
|
||||
)
|
||||
}
|
||||
@@ -11,11 +11,9 @@ use codex_core::openai_api_key::get_openai_api_key;
|
||||
use codex_core::openai_api_key::set_openai_api_key;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use codex_core::util::maybe_read_file;
|
||||
use codex_login::try_read_openai_api_key;
|
||||
use log_layer::TuiLogLayer;
|
||||
use std::fs::OpenOptions;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tracing_appender::non_blocking;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
@@ -35,10 +33,10 @@ mod file_search;
|
||||
mod get_git_diff;
|
||||
mod git_warning_screen;
|
||||
mod history_cell;
|
||||
mod insert_history;
|
||||
mod log_layer;
|
||||
mod login_screen;
|
||||
mod markdown;
|
||||
mod mouse_capture;
|
||||
mod scroll_event_helper;
|
||||
mod slash_command;
|
||||
mod status_indicator_widget;
|
||||
@@ -49,10 +47,7 @@ mod user_approval_widget;
|
||||
|
||||
pub use cli::Cli;
|
||||
|
||||
pub fn run_main(
|
||||
cli: Cli,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
) -> std::io::Result<codex_core::protocol::TokenUsage> {
|
||||
pub fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> std::io::Result<()> {
|
||||
let (sandbox_mode, approval_policy) = if cli.full_auto {
|
||||
(
|
||||
Some(SandboxMode::WorkspaceWrite),
|
||||
@@ -70,50 +65,8 @@ pub fn run_main(
|
||||
)
|
||||
};
|
||||
|
||||
// Capture any read error for experimental instructions so we can log it
|
||||
// after the tracing subscriber has been initialized.
|
||||
let mut experimental_read_error: Option<String> = None;
|
||||
|
||||
let (config, experimental_prompt_label) = {
|
||||
let config = {
|
||||
// Load configuration and support CLI overrides.
|
||||
// If the experimental instructions flag points at a file, read its
|
||||
// contents; otherwise use the value verbatim. Avoid printing to stdout
|
||||
// or stderr in this library crate – fallback to the raw string on
|
||||
// errors.
|
||||
let base_instructions =
|
||||
cli.experimental_instructions
|
||||
.as_deref()
|
||||
.and_then(|s| match maybe_read_file(s) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
experimental_read_error = Some(format!(
|
||||
"Failed to read experimental instructions from '{s}': {e}"
|
||||
));
|
||||
Some(s.to_string())
|
||||
}
|
||||
});
|
||||
|
||||
// Derive a label shown in the welcome banner describing the origin of
|
||||
// the experimental instructions: filename for file paths and
|
||||
// "experimental" for literals.
|
||||
let experimental_prompt_label = cli.experimental_instructions.as_deref().map(|s| {
|
||||
let p = Path::new(s);
|
||||
if p.is_file() {
|
||||
p.file_name()
|
||||
.map(|os| os.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| s.to_string())
|
||||
} else {
|
||||
"experimental".to_string()
|
||||
}
|
||||
});
|
||||
|
||||
// Do not show a label if the file was empty (base_instructions is None).
|
||||
let experimental_prompt_label = if base_instructions.is_some() {
|
||||
experimental_prompt_label
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let overrides = ConfigOverrides {
|
||||
model: cli.model.clone(),
|
||||
approval_policy,
|
||||
@@ -122,7 +75,6 @@ pub fn run_main(
|
||||
model_provider: None,
|
||||
config_profile: cli.config_profile.clone(),
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
};
|
||||
// Parse `-c` overrides from the CLI.
|
||||
let cli_kv_overrides = match cli.config_overrides.parse_overrides() {
|
||||
@@ -136,7 +88,7 @@ pub fn run_main(
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
match Config::load_with_cli_overrides(cli_kv_overrides, overrides) {
|
||||
Ok(config) => (config, experimental_prompt_label),
|
||||
Ok(config) => config,
|
||||
Err(err) => {
|
||||
eprintln!("Error loading configuration: {err}");
|
||||
std::process::exit(1);
|
||||
@@ -186,12 +138,6 @@ pub fn run_main(
|
||||
.with(tui_layer)
|
||||
.try_init();
|
||||
|
||||
if let Some(msg) = experimental_read_error {
|
||||
// Now that logging is initialized, record a warning so the user
|
||||
// can see that Codex fell back to using the literal string.
|
||||
tracing::warn!("{msg}");
|
||||
}
|
||||
|
||||
let show_login_screen = should_show_login_screen(&config);
|
||||
|
||||
// Determine whether we need to display the "not a git repo" warning
|
||||
@@ -200,15 +146,24 @@ pub fn run_main(
|
||||
// `--allow-no-git-exec` flag.
|
||||
let show_git_warning = !cli.skip_git_repo_check && !is_inside_git_repo(&config);
|
||||
|
||||
run_ratatui_app(
|
||||
cli,
|
||||
config,
|
||||
show_login_screen,
|
||||
show_git_warning,
|
||||
experimental_prompt_label,
|
||||
log_rx,
|
||||
)
|
||||
.map_err(|err| std::io::Error::other(err.to_string()))
|
||||
try_run_ratatui_app(cli, config, show_login_screen, show_git_warning, log_rx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::print_stderr,
|
||||
reason = "Resort to stderr in exceptional situations."
|
||||
)]
|
||||
fn try_run_ratatui_app(
|
||||
cli: Cli,
|
||||
config: Config,
|
||||
show_login_screen: bool,
|
||||
show_git_warning: bool,
|
||||
log_rx: tokio::sync::mpsc::UnboundedReceiver<String>,
|
||||
) {
|
||||
if let Err(report) = run_ratatui_app(cli, config, show_login_screen, show_git_warning, log_rx) {
|
||||
eprintln!("Error: {report:?}");
|
||||
}
|
||||
}
|
||||
|
||||
fn run_ratatui_app(
|
||||
@@ -216,17 +171,17 @@ fn run_ratatui_app(
|
||||
config: Config,
|
||||
show_login_screen: bool,
|
||||
show_git_warning: bool,
|
||||
experimental_prompt_label: Option<String>,
|
||||
mut log_rx: tokio::sync::mpsc::UnboundedReceiver<String>,
|
||||
) -> color_eyre::Result<codex_core::protocol::TokenUsage> {
|
||||
) -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Forward panic reports through tracing so they appear in the UI status
|
||||
// line instead of interleaving raw panic output with the interface.
|
||||
// Forward panic reports through the tracing stack so that they appear in
|
||||
// the status indicator instead of breaking the alternate screen – the
|
||||
// normal colour‑eyre hook writes to stderr which would corrupt the UI.
|
||||
std::panic::set_hook(Box::new(|info| {
|
||||
tracing::error!("panic: {info}");
|
||||
}));
|
||||
let mut terminal = tui::init(&config)?;
|
||||
let (mut terminal, mut mouse_capture) = tui::init(&config)?;
|
||||
terminal.clear()?;
|
||||
|
||||
let Cli { prompt, images, .. } = cli;
|
||||
@@ -236,7 +191,6 @@ fn run_ratatui_app(
|
||||
show_login_screen,
|
||||
show_git_warning,
|
||||
images,
|
||||
experimental_prompt_label,
|
||||
);
|
||||
|
||||
// Bridge log receiver into the AppEvent channel so latest log lines update the UI.
|
||||
@@ -249,12 +203,10 @@ fn run_ratatui_app(
|
||||
});
|
||||
}
|
||||
|
||||
let app_result = app.run(&mut terminal);
|
||||
let usage = app.token_usage();
|
||||
let app_result = app.run(&mut terminal, &mut mouse_capture);
|
||||
|
||||
restore();
|
||||
// ignore error when collecting usage – report underlying error instead
|
||||
app_result.map(|_| usage)
|
||||
app_result
|
||||
}
|
||||
|
||||
#[expect(
|
||||
@@ -303,56 +255,3 @@ fn is_in_need_of_openai_api_key(config: &Config) -> bool {
|
||||
.unwrap_or(false);
|
||||
is_using_openai_key && get_openai_api_key().is_none()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use codex_core::util::maybe_read_file;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn temp_path() -> PathBuf {
|
||||
let mut p = std::env::temp_dir();
|
||||
p.push(format!("codex_tui_test_{}.txt", Uuid::new_v4()));
|
||||
p
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_read_file_returns_literal_for_non_path() {
|
||||
let res = match maybe_read_file("Base instructions as a string") {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!("error: {e}"),
|
||||
};
|
||||
assert_eq!(res, Some("Base instructions as a string".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_read_file_reads_and_trims_file_contents() {
|
||||
let p = temp_path();
|
||||
if let Err(e) = fs::write(&p, " file text \n") {
|
||||
panic!("write temp file: {e}");
|
||||
}
|
||||
let p_s = p.to_string_lossy().to_string();
|
||||
let res = match maybe_read_file(&p_s) {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!("error: {e}"),
|
||||
};
|
||||
assert_eq!(res, Some("file text".to_string()));
|
||||
let _ = std::fs::remove_file(&p);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_read_file_empty_file_returns_none() {
|
||||
let p = temp_path();
|
||||
if let Err(e) = fs::write(&p, " \n\t") {
|
||||
panic!("write temp file: {e}");
|
||||
}
|
||||
let p_s = p.to_string_lossy().to_string();
|
||||
let res = match maybe_read_file(&p_s) {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!("error: {e}"),
|
||||
};
|
||||
assert_eq!(res, None);
|
||||
let _ = std::fs::remove_file(&p);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use clap::Parser;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_tui::Cli;
|
||||
use codex_tui::run_main;
|
||||
@@ -14,15 +13,14 @@ struct TopCli {
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move {
|
||||
let top_cli = TopCli::parse();
|
||||
let mut inner = top_cli.inner;
|
||||
inner
|
||||
.config_overrides
|
||||
.raw_overrides
|
||||
.splice(0..0, top_cli.config_overrides.raw_overrides);
|
||||
let usage = run_main(inner, codex_linux_sandbox_exe)?;
|
||||
println!("{}", codex_core::protocol::FinalOutput::from(usage));
|
||||
run_main(inner, codex_linux_sandbox_exe)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
69
codex-rs/tui/src/mouse_capture.rs
Normal file
69
codex-rs/tui/src/mouse_capture.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use crossterm::event::DisableMouseCapture;
|
||||
use crossterm::event::EnableMouseCapture;
|
||||
use ratatui::crossterm::execute;
|
||||
use std::io::Result;
|
||||
use std::io::stdout;
|
||||
|
||||
pub(crate) struct MouseCapture {
|
||||
mouse_capture_is_active: bool,
|
||||
}
|
||||
|
||||
impl MouseCapture {
|
||||
pub(crate) fn new_with_capture(mouse_capture_is_active: bool) -> Result<Self> {
|
||||
if mouse_capture_is_active {
|
||||
enable_capture()?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
mouse_capture_is_active,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MouseCapture {
|
||||
/// Idempotent method to set the mouse capture state.
|
||||
pub fn set_active(&mut self, is_active: bool) -> Result<()> {
|
||||
match (self.mouse_capture_is_active, is_active) {
|
||||
(true, true) => {}
|
||||
(false, false) => {}
|
||||
(true, false) => {
|
||||
disable_capture()?;
|
||||
self.mouse_capture_is_active = false;
|
||||
}
|
||||
(false, true) => {
|
||||
enable_capture()?;
|
||||
self.mouse_capture_is_active = true;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn toggle(&mut self) -> Result<()> {
|
||||
self.set_active(!self.mouse_capture_is_active)
|
||||
}
|
||||
|
||||
pub(crate) fn disable(&mut self) -> Result<()> {
|
||||
if self.mouse_capture_is_active {
|
||||
disable_capture()?;
|
||||
self.mouse_capture_is_active = false;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MouseCapture {
|
||||
fn drop(&mut self) {
|
||||
if self.disable().is_err() {
|
||||
// The user is likely shutting down, so ignore any errors so the
|
||||
// shutdown process can complete.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enable_capture() -> Result<()> {
|
||||
execute!(stdout(), EnableMouseCapture)
|
||||
}
|
||||
|
||||
fn disable_capture() -> Result<()> {
|
||||
execute!(stdout(), DisableMouseCapture)
|
||||
}
|
||||
@@ -15,6 +15,7 @@ pub enum SlashCommand {
|
||||
New,
|
||||
Diff,
|
||||
Quit,
|
||||
ToggleMouseMode,
|
||||
}
|
||||
|
||||
impl SlashCommand {
|
||||
@@ -22,6 +23,9 @@ impl SlashCommand {
|
||||
pub fn description(self) -> &'static str {
|
||||
match self {
|
||||
SlashCommand::New => "Start a new chat.",
|
||||
SlashCommand::ToggleMouseMode => {
|
||||
"Toggle mouse mode (enable for scrolling, disable for text selection)"
|
||||
}
|
||||
SlashCommand::Quit => "Exit the application.",
|
||||
SlashCommand::Diff => {
|
||||
"Show git diff of the working directory (including untracked files)"
|
||||
|
||||
@@ -34,6 +34,11 @@ pub(crate) struct StatusIndicatorWidget {
|
||||
/// time).
|
||||
text: String,
|
||||
|
||||
/// Height in terminal rows – matches the height of the textarea at the
|
||||
/// moment the task started so the UI does not jump when we toggle between
|
||||
/// input mode and loading mode.
|
||||
height: u16,
|
||||
|
||||
frame_idx: Arc<AtomicUsize>,
|
||||
running: Arc<AtomicBool>,
|
||||
// Keep one sender alive to prevent the channel from closing while the
|
||||
@@ -45,7 +50,7 @@ pub(crate) struct StatusIndicatorWidget {
|
||||
|
||||
impl StatusIndicatorWidget {
|
||||
/// Create a new status indicator and start the animation timer.
|
||||
pub(crate) fn new(app_event_tx: AppEventSender) -> Self {
|
||||
pub(crate) fn new(app_event_tx: AppEventSender, height: u16) -> Self {
|
||||
let frame_idx = Arc::new(AtomicUsize::new(0));
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
|
||||
@@ -67,12 +72,18 @@ impl StatusIndicatorWidget {
|
||||
|
||||
Self {
|
||||
text: String::from("waiting for logs…"),
|
||||
height: height.max(3),
|
||||
frame_idx,
|
||||
running,
|
||||
_app_event_tx: app_event_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Preferred height in terminal rows.
|
||||
pub(crate) fn get_height(&self) -> u16 {
|
||||
self.height
|
||||
}
|
||||
|
||||
/// Update the line that is displayed in the widget.
|
||||
pub(crate) fn update_text(&mut self, text: String) {
|
||||
self.text = text.replace(['\n', '\r'], " ");
|
||||
|
||||
@@ -4,39 +4,31 @@ use std::io::stdout;
|
||||
|
||||
use codex_core::config::Config;
|
||||
use crossterm::event::DisableBracketedPaste;
|
||||
use crossterm::event::DisableMouseCapture;
|
||||
use crossterm::event::EnableBracketedPaste;
|
||||
use ratatui::Terminal;
|
||||
use ratatui::TerminalOptions;
|
||||
use ratatui::Viewport;
|
||||
use ratatui::backend::CrosstermBackend;
|
||||
use ratatui::crossterm::execute;
|
||||
use ratatui::crossterm::terminal::EnterAlternateScreen;
|
||||
use ratatui::crossterm::terminal::LeaveAlternateScreen;
|
||||
use ratatui::crossterm::terminal::disable_raw_mode;
|
||||
use ratatui::crossterm::terminal::enable_raw_mode;
|
||||
|
||||
use crate::mouse_capture::MouseCapture;
|
||||
|
||||
/// A type alias for the terminal type used in this application
|
||||
pub type Tui = Terminal<CrosstermBackend<Stdout>>;
|
||||
|
||||
/// Initialize the terminal (inline viewport; history stays in normal scrollback)
|
||||
pub fn init(_config: &Config) -> Result<Tui> {
|
||||
/// Initialize the terminal
|
||||
pub fn init(config: &Config) -> Result<(Tui, MouseCapture)> {
|
||||
execute!(stdout(), EnterAlternateScreen)?;
|
||||
execute!(stdout(), EnableBracketedPaste)?;
|
||||
let mouse_capture = MouseCapture::new_with_capture(!config.tui.disable_mouse_capture)?;
|
||||
|
||||
enable_raw_mode()?;
|
||||
set_panic_hook();
|
||||
|
||||
// Reserve a fixed number of lines for the interactive viewport (composer,
|
||||
// status, popups). History is injected above using `insert_before`. This
|
||||
// is an initial step of the refactor – later the height can become
|
||||
// dynamic. For now a conservative default keeps enough room for the
|
||||
// multi‑line composer while not occupying the whole screen.
|
||||
const BOTTOM_VIEWPORT_HEIGHT: u16 = 8;
|
||||
let backend = CrosstermBackend::new(stdout());
|
||||
let tui = Terminal::with_options(
|
||||
backend,
|
||||
TerminalOptions {
|
||||
viewport: Viewport::Inline(BOTTOM_VIEWPORT_HEIGHT),
|
||||
},
|
||||
)?;
|
||||
Ok(tui)
|
||||
let tui = Terminal::new(CrosstermBackend::new(stdout()))?;
|
||||
Ok((tui, mouse_capture))
|
||||
}
|
||||
|
||||
fn set_panic_hook() {
|
||||
@@ -49,7 +41,14 @@ fn set_panic_hook() {
|
||||
|
||||
/// Restore the terminal to its original state
|
||||
pub fn restore() -> Result<()> {
|
||||
// We are shutting down, and we cannot reference the `MouseCapture`, so we
|
||||
// categorically disable mouse capture just to be safe.
|
||||
if execute!(stdout(), DisableMouseCapture).is_err() {
|
||||
// It is possible that `DisableMouseCapture` is written more than once
|
||||
// on shutdown, so ignore the error in this case.
|
||||
}
|
||||
execute!(stdout(), DisableBracketedPaste)?;
|
||||
execute!(stdout(), LeaveAlternateScreen)?;
|
||||
disable_raw_mode()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user