mirror of
https://github.com/openai/codex.git
synced 2026-06-02 11:22:01 +00:00
Compare commits
45 Commits
jif/state7
...
codex/upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
795e4c91fb | ||
|
|
43b63ccae8 | ||
|
|
cc1b21e47f | ||
|
|
55801700de | ||
|
|
1fba99ed85 | ||
|
|
d3f6f6629b | ||
|
|
e555a36c6a | ||
|
|
ea095e30c1 | ||
|
|
c549481513 | ||
|
|
8797145678 | ||
|
|
a53720e278 | ||
|
|
41f5d61f24 | ||
|
|
02609184be | ||
|
|
1fc3413a46 | ||
|
|
eb2b739d6a | ||
|
|
a10403d697 | ||
|
|
8e3a048fec | ||
|
|
9f2ab97fbc | ||
|
|
38c9d7dca1 | ||
|
|
67aab04c66 | ||
|
|
7355ca48c5 | ||
|
|
affb5fc1d0 | ||
|
|
4a5f05c136 | ||
|
|
acc2b63dfb | ||
|
|
344d4a1d68 | ||
|
|
a0c37f5d07 | ||
|
|
103adcdf2d | ||
|
|
d61dea6fe6 | ||
|
|
e363dac249 | ||
|
|
250b244ab4 | ||
|
|
d1ed3a4cef | ||
|
|
e85742635f | ||
|
|
87b299aa3f | ||
|
|
0e58870634 | ||
|
|
42847baaf7 | ||
|
|
6032d784ee | ||
|
|
7bff8df10e | ||
|
|
addc946d13 | ||
|
|
bffdbec2c5 | ||
|
|
353a5c2046 | ||
|
|
00c7f7a16c | ||
|
|
82e65975b2 | ||
|
|
639a6fd2f3 | ||
|
|
db4aa6f916 | ||
|
|
cb96f4f596 |
2
.github/workflows/codespell.yml
vendored
2
.github/workflows/codespell.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Annotate locations with typos
|
||||
uses: codespell-project/codespell-problem-matcher@b80729f885d32f78a716c2f107b4db1025001c42 # v1
|
||||
- name: Codespell
|
||||
uses: codespell-project/actions-codespell@406322ec52dd7b488e48c1c4b82e2a8b3a1bf630 # v2
|
||||
uses: codespell-project/actions-codespell@406322ec52dd7b488e48c1c4b82e2a8b3a1bf630 # v2.1
|
||||
with:
|
||||
ignore_words_file: .codespellignore
|
||||
skip: frame*.txt
|
||||
|
||||
6
.github/workflows/rust-ci.yml
vendored
6
.github/workflows/rust-ci.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
components: rustfmt
|
||||
- name: cargo fmt
|
||||
@@ -75,7 +75,7 @@ jobs:
|
||||
working-directory: codex-rs
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
- uses: taiki-e/install-action@0c5db7f7f897c03b771660e91d065338615679f4 # v2
|
||||
with:
|
||||
tool: cargo-shear
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
components: clippy
|
||||
|
||||
2
.github/workflows/rust-release.yml
vendored
2
.github/workflows/rust-release.yml
vendored
@@ -77,7 +77,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: dtolnay/rust-toolchain@1.89
|
||||
- uses: dtolnay/rust-toolchain@1.90
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
<h1 align="center">OpenAI Codex CLI</h1>
|
||||
|
||||
<p align="center"><code>npm i -g @openai/codex</code><br />or <code>brew install codex</code></p>
|
||||
|
||||
@@ -102,4 +101,3 @@ Codex CLI supports a rich set of configuration options, with preferences stored
|
||||
## License
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ CODEX_CLI_ROOT = SCRIPT_DIR.parent
|
||||
REPO_ROOT = CODEX_CLI_ROOT.parent
|
||||
GITHUB_REPO = "openai/codex"
|
||||
|
||||
# The docs are not clear on what the expected value/format of
|
||||
# workflow/workflowName is:
|
||||
# https://cli.github.com/manual/gh_run_list
|
||||
WORKFLOW_NAME = ".github/workflows/rust-release.yml"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Build or stage the Codex CLI npm package.")
|
||||
@@ -163,10 +168,8 @@ def install_native_binaries(staging_dir: Path, workflow_url: str | None) -> None
|
||||
|
||||
def resolve_latest_alpha_workflow_url() -> str:
|
||||
version = determine_latest_alpha_version()
|
||||
workflow_url = fetch_workflow_url_for_version(version)
|
||||
if not workflow_url:
|
||||
raise RuntimeError(f"Unable to locate workflow for version {version}.")
|
||||
return workflow_url
|
||||
workflow = resolve_release_workflow(version)
|
||||
return workflow["url"]
|
||||
|
||||
|
||||
def determine_latest_alpha_version() -> str:
|
||||
@@ -205,36 +208,6 @@ def list_releases() -> list[dict]:
|
||||
return releases
|
||||
|
||||
|
||||
def fetch_workflow_url_for_version(version: str) -> str | None:
|
||||
ref = f"rust-v{version}"
|
||||
stdout = subprocess.check_output(
|
||||
[
|
||||
"gh",
|
||||
"run",
|
||||
"list",
|
||||
"--branch",
|
||||
ref,
|
||||
"--limit",
|
||||
"20",
|
||||
"--json",
|
||||
"workflowName,url",
|
||||
],
|
||||
text=True,
|
||||
)
|
||||
|
||||
try:
|
||||
runs = json.loads(stdout or "[]")
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError("Unable to parse workflow run listing.") from exc
|
||||
|
||||
for run in runs:
|
||||
if run.get("workflowName") == "rust-release":
|
||||
url = run.get("url")
|
||||
if url:
|
||||
return url
|
||||
return None
|
||||
|
||||
|
||||
def resolve_release_workflow(version: str) -> dict:
|
||||
stdout = subprocess.check_output(
|
||||
[
|
||||
@@ -245,12 +218,14 @@ def resolve_release_workflow(version: str) -> dict:
|
||||
f"rust-v{version}",
|
||||
"--json",
|
||||
"workflowName,url,headSha",
|
||||
"--workflow",
|
||||
WORKFLOW_NAME,
|
||||
"--jq",
|
||||
'first(.[] | select(.workflowName == "rust-release"))',
|
||||
"first(.[])",
|
||||
],
|
||||
text=True,
|
||||
)
|
||||
workflow = json.loads(stdout)
|
||||
workflow = json.loads(stdout or "[]")
|
||||
if not workflow:
|
||||
raise RuntimeError(f"Unable to find rust-release workflow for version {version}.")
|
||||
return workflow
|
||||
|
||||
417
codex-rs/Cargo.lock
generated
417
codex-rs/Cargo.lock
generated
@@ -56,7 +56,7 @@ checksum = "8fac2ce611db8b8cee9b2aa886ca03c924e9da5e5295d0dbd0526e5d0b0710f7"
|
||||
dependencies = [
|
||||
"allocative_derive",
|
||||
"bumpalo",
|
||||
"ctor",
|
||||
"ctor 0.1.26",
|
||||
"hashbrown 0.14.5",
|
||||
"num-bigint",
|
||||
]
|
||||
@@ -78,12 +78,6 @@ version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
@@ -495,18 +489,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.41"
|
||||
name = "cfg_aliases"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.42"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
"windows-link 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -595,7 +594,6 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"once_cell",
|
||||
"pretty_assertions",
|
||||
"similar",
|
||||
"tempfile",
|
||||
@@ -648,7 +646,10 @@ dependencies = [
|
||||
"codex-mcp-server",
|
||||
"codex-protocol",
|
||||
"codex-protocol-ts",
|
||||
"codex-responses-api-proxy",
|
||||
"codex-tui",
|
||||
"ctor 0.5.0",
|
||||
"libc",
|
||||
"owo-colors",
|
||||
"predicates",
|
||||
"pretty_assertions",
|
||||
@@ -679,6 +680,7 @@ dependencies = [
|
||||
"askama",
|
||||
"assert_cmd",
|
||||
"async-channel",
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -686,11 +688,14 @@ dependencies = [
|
||||
"codex-file-search",
|
||||
"codex-mcp-client",
|
||||
"codex-protocol",
|
||||
"codex-rmcp-client",
|
||||
"core_test_support",
|
||||
"dirs",
|
||||
"env-flags",
|
||||
"escargot",
|
||||
"eventsource-stream",
|
||||
"futures",
|
||||
"indexmap 2.10.0",
|
||||
"landlock",
|
||||
"libc",
|
||||
"maplit",
|
||||
@@ -745,12 +750,15 @@ dependencies = [
|
||||
"libc",
|
||||
"owo-colors",
|
||||
"predicates",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"ts-rs",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
"wiremock",
|
||||
@@ -927,6 +935,36 @@ dependencies = [
|
||||
"ts-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-responses-api-proxy"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"codex-arg0",
|
||||
"libc",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tiny_http",
|
||||
"tokio",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-rmcp-client"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"mcp-types",
|
||||
"pretty_assertions",
|
||||
"rmcp",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-tui"
|
||||
version = "0.0.0"
|
||||
@@ -956,7 +994,6 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"mcp-types",
|
||||
"once_cell",
|
||||
"path-clean",
|
||||
"pathdiff",
|
||||
"pretty_assertions",
|
||||
@@ -978,7 +1015,7 @@ dependencies = [
|
||||
"tracing-appender",
|
||||
"tracing-subscriber",
|
||||
"unicode-segmentation",
|
||||
"unicode-width 0.1.14",
|
||||
"unicode-width 0.2.1",
|
||||
"url",
|
||||
"vt100",
|
||||
]
|
||||
@@ -1111,6 +1148,7 @@ name = "core_test_support"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"codex-core",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
@@ -1222,14 +1260,40 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctor"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67773048316103656a637612c4a62477603b777d91d9c62ff2290f9cde178fdb"
|
||||
dependencies = [
|
||||
"ctor-proc-macro",
|
||||
"dtor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctor-proc-macro"
|
||||
version = "0.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2"
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
"darling_core 0.20.11",
|
||||
"darling_macro 0.20.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
|
||||
dependencies = [
|
||||
"darling_core 0.21.3",
|
||||
"darling_macro 0.21.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1246,13 +1310,38 @@ dependencies = [
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim 0.11.1",
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_core 0.20.11",
|
||||
"quote",
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
|
||||
dependencies = [
|
||||
"darling_core 0.21.3",
|
||||
"quote",
|
||||
"syn 2.0.104",
|
||||
]
|
||||
@@ -1472,6 +1561,21 @@ version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"
|
||||
|
||||
[[package]]
|
||||
name = "dtor"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e58a0764cddb55ab28955347b45be00ade43d4d6f3ba4bf3dc354e4ec9432934"
|
||||
dependencies = [
|
||||
"dtor-proc-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dtor-proc-macro"
|
||||
version = "0.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5"
|
||||
|
||||
[[package]]
|
||||
name = "dupe"
|
||||
version = "0.9.1"
|
||||
@@ -1614,6 +1718,17 @@ version = "3.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59"
|
||||
|
||||
[[package]]
|
||||
name = "escargot"
|
||||
version = "0.5.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11c3aea32bc97b500c9ca6a72b768a26e558264303d101d3409cf6d57a9ed0cf"
|
||||
dependencies = [
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
version = "5.4.0"
|
||||
@@ -1919,8 +2034,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi 0.11.1+wasi-snapshot-preview1",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1930,9 +2047,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasi 0.14.2+wasi-0.2.4",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2129,6 +2248,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2438,7 +2558,7 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "435d80800b936787d62688c927b6490e887c7ef5ff9ce922c6c6050fca75eb9a"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"darling 0.20.11",
|
||||
"indoc",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2729,6 +2849,12 @@ dependencies = [
|
||||
"hashbrown 0.15.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "lsp-types"
|
||||
version = "0.94.1"
|
||||
@@ -2908,7 +3034,19 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"cfg-if",
|
||||
"cfg_aliases",
|
||||
"cfg_aliases 0.1.1",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6"
|
||||
dependencies = [
|
||||
"bitflags 2.9.1",
|
||||
"cfg-if",
|
||||
"cfg_aliases 0.2.1",
|
||||
"libc",
|
||||
]
|
||||
|
||||
@@ -3334,7 +3472,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"log",
|
||||
"nix",
|
||||
"nix 0.28.0",
|
||||
"serial2",
|
||||
"shared_library",
|
||||
"shell-words",
|
||||
@@ -3422,6 +3560,20 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "process-wrap"
|
||||
version = "8.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"indexmap 2.10.0",
|
||||
"nix 0.30.1",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"windows",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pulldown-cmark"
|
||||
version = "0.10.3"
|
||||
@@ -3465,6 +3617,61 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"cfg_aliases 0.2.1",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"socket2",
|
||||
"thiserror 2.0.16",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"getrandom 0.3.3",
|
||||
"lru-slab",
|
||||
"rand",
|
||||
"ring",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror 2.0.16",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
|
||||
dependencies = [
|
||||
"cfg_aliases 0.2.1",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.40"
|
||||
@@ -3657,6 +3864,8 @@ dependencies = [
|
||||
"native-tls",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"quinn",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -3664,6 +3873,7 @@ dependencies = [
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-http",
|
||||
@@ -3673,6 +3883,7 @@ dependencies = [
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3689,12 +3900,54 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rmcp"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"chrono",
|
||||
"futures",
|
||||
"paste",
|
||||
"pin-project-lite",
|
||||
"process-wrap",
|
||||
"rmcp-macros",
|
||||
"schemars 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.16",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rmcp-macros"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ba777eb0e5f53a757e36f0e287441da0ab766564ba7201600eeb92a4753022e"
|
||||
dependencies = [
|
||||
"darling 0.21.3",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde_json",
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.44"
|
||||
@@ -3728,6 +3981,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"subtle",
|
||||
@@ -3740,6 +3994,7 @@ version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79"
|
||||
dependencies = [
|
||||
"web-time",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
@@ -3774,7 +4029,7 @@ dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"memchr",
|
||||
"nix",
|
||||
"nix 0.28.0",
|
||||
"radix_trie",
|
||||
"unicode-segmentation",
|
||||
"unicode-width 0.1.14",
|
||||
@@ -3855,7 +4110,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
|
||||
dependencies = [
|
||||
"dyn-clone",
|
||||
"schemars_derive",
|
||||
"schemars_derive 0.8.22",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
@@ -3878,8 +4133,10 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"dyn-clone",
|
||||
"ref-cast",
|
||||
"schemars_derive 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
@@ -3896,6 +4153,18 @@ dependencies = [
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schemars_derive"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33d020396d1d138dc19f1165df7545479dcd58d93810dc5d646a16e55abefa80"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde_derive_internals",
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
@@ -4047,7 +4316,7 @@ version = "3.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"darling 0.20.11",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.104",
|
||||
@@ -4654,6 +4923,21 @@ dependencies = [
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec_macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.47.1"
|
||||
@@ -5266,6 +5550,16 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-time"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webbrowser"
|
||||
version = "1.0.5"
|
||||
@@ -5282,6 +5576,15 @@ dependencies = [
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.10"
|
||||
@@ -5337,6 +5640,28 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows"
|
||||
version = "0.61.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893"
|
||||
dependencies = [
|
||||
"windows-collections",
|
||||
"windows-core",
|
||||
"windows-future",
|
||||
"windows-link 0.1.3",
|
||||
"windows-numerics",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-collections"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8"
|
||||
dependencies = [
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.61.2"
|
||||
@@ -5345,11 +5670,22 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3"
|
||||
dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-link 0.1.3",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-future"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e"
|
||||
dependencies = [
|
||||
"windows-core",
|
||||
"windows-link 0.1.3",
|
||||
"windows-threading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.0"
|
||||
@@ -5378,13 +5714,29 @@ version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65"
|
||||
|
||||
[[package]]
|
||||
name = "windows-numerics"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1"
|
||||
dependencies = [
|
||||
"windows-core",
|
||||
"windows-link 0.1.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-registry"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
"windows-link 0.1.3",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
@@ -5395,7 +5747,7 @@ version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
"windows-link 0.1.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5404,7 +5756,7 @@ version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
"windows-link 0.1.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5505,6 +5857,15 @@ dependencies = [
|
||||
"windows_x86_64_msvc 0.53.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-threading"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6"
|
||||
dependencies = [
|
||||
"windows-link 0.1.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
|
||||
@@ -18,6 +18,8 @@ members = [
|
||||
"ollama",
|
||||
"protocol",
|
||||
"protocol-ts",
|
||||
"rmcp-client",
|
||||
"responses-api-proxy",
|
||||
"tui",
|
||||
"utils/readiness",
|
||||
]
|
||||
@@ -48,7 +50,9 @@ codex-mcp-client = { path = "mcp-client" }
|
||||
codex-mcp-server = { path = "mcp-server" }
|
||||
codex-ollama = { path = "ollama" }
|
||||
codex-protocol = { path = "protocol" }
|
||||
codex-rmcp-client = { path = "rmcp-client" }
|
||||
codex-protocol-ts = { path = "protocol-ts" }
|
||||
codex-responses-api-proxy = { path = "responses-api-proxy" }
|
||||
codex-tui = { path = "tui" }
|
||||
codex-utils-readiness = { path = "utils/readiness" }
|
||||
core_test_support = { path = "core/tests/common" }
|
||||
@@ -67,11 +71,12 @@ async-stream = "0.3.6"
|
||||
async-trait = "0.1.89"
|
||||
base64 = "0.22.1"
|
||||
bytes = "1.10.1"
|
||||
chrono = "0.4.40"
|
||||
chrono = "0.4.42"
|
||||
clap = "4"
|
||||
clap_complete = "4"
|
||||
color-eyre = "0.6.3"
|
||||
crossterm = "0.28.1"
|
||||
ctor = "0.5.0"
|
||||
derive_more = "2"
|
||||
diffy = "0.4.2"
|
||||
dirs = "6"
|
||||
@@ -79,11 +84,13 @@ dotenvy = "0.15.7"
|
||||
env-flags = "0.1.1"
|
||||
env_logger = "0.11.5"
|
||||
eventsource-stream = "0.2.3"
|
||||
escargot = "0.5"
|
||||
futures = "0.3"
|
||||
icu_decimal = "2.0.0"
|
||||
icu_locale_core = "2.0.0"
|
||||
ignore = "0.4.23"
|
||||
image = { version = "^0.25.8", default-features = false }
|
||||
indexmap = "2.6.0"
|
||||
insta = "1.43.2"
|
||||
itertools = "0.14.0"
|
||||
landlock = "0.4.1"
|
||||
@@ -94,7 +101,6 @@ maplit = "1.0.2"
|
||||
mime_guess = "2.0.5"
|
||||
multimap = "0.10.0"
|
||||
nucleo-matcher = "0.3.1"
|
||||
once_cell = "1"
|
||||
openssl-sys = "*"
|
||||
os_info = "3.12.0"
|
||||
owo-colors = "4.2.0"
|
||||
@@ -141,7 +147,7 @@ tree-sitter = "0.25.9"
|
||||
tree-sitter-bash = "0.25.0"
|
||||
ts-rs = "11"
|
||||
unicode-segmentation = "1.12.0"
|
||||
unicode-width = "0.1"
|
||||
unicode-width = "0.2"
|
||||
url = "2"
|
||||
urlencoding = "2.1"
|
||||
uuid = "1"
|
||||
@@ -151,6 +157,7 @@ webbrowser = "1.0"
|
||||
which = "6"
|
||||
wildmatch = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
zeroize = "1.8.1"
|
||||
|
||||
[workspace.lints]
|
||||
rust = {}
|
||||
|
||||
@@ -20,7 +20,6 @@ similar = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tree-sitter = { workspace = true }
|
||||
tree-sitter-bash = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
|
||||
@@ -6,10 +6,10 @@ use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::str::Utf8Error;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use once_cell::sync::Lazy;
|
||||
pub use parser::Hunk;
|
||||
pub use parser::ParseError;
|
||||
use parser::ParseError::*;
|
||||
@@ -351,7 +351,7 @@ fn extract_apply_patch_from_bash(
|
||||
// also run an arbitrary query against the AST. This is useful for understanding
|
||||
// how tree-sitter parses the script and whether the query syntax is correct. Be sure
|
||||
// to test both positive and negative cases.
|
||||
static APPLY_PATCH_QUERY: Lazy<Query> = Lazy::new(|| {
|
||||
static APPLY_PATCH_QUERY: LazyLock<Query> = LazyLock::new(|| {
|
||||
let language = BASH.into();
|
||||
#[expect(clippy::expect_used)]
|
||||
Query::new(
|
||||
|
||||
@@ -27,7 +27,9 @@ codex-login = { workspace = true }
|
||||
codex-mcp-server = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-protocol-ts = { workspace = true }
|
||||
codex-responses-api-proxy = { workspace = true }
|
||||
codex-tui = { workspace = true }
|
||||
ctor = { workspace = true }
|
||||
owo-colors = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
supports-color = { workspace = true }
|
||||
@@ -41,6 +43,15 @@ tokio = { workspace = true, features = [
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
libc = { workspace = true }
|
||||
|
||||
[target.'cfg(target_os = "android")'.dependencies]
|
||||
libc = { workspace = true }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
libc = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use anyhow::Context;
|
||||
use clap::CommandFactory;
|
||||
use clap::Parser;
|
||||
use clap_complete::Shell;
|
||||
@@ -14,6 +15,7 @@ use codex_cli::login::run_logout;
|
||||
use codex_cli::proto;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_exec::Cli as ExecCli;
|
||||
use codex_responses_api_proxy::Args as ResponsesApiProxyArgs;
|
||||
use codex_tui::AppExitInfo;
|
||||
use codex_tui::Cli as TuiCli;
|
||||
use owo_colors::OwoColorize;
|
||||
@@ -21,6 +23,7 @@ use std::path::PathBuf;
|
||||
use supports_color::Stream;
|
||||
|
||||
mod mcp_cmd;
|
||||
mod pre_main_hardening;
|
||||
|
||||
use crate::mcp_cmd::McpCli;
|
||||
use crate::proto::ProtoCli;
|
||||
@@ -85,6 +88,10 @@ enum Subcommand {
|
||||
/// Internal: generate TypeScript protocol bindings.
|
||||
#[clap(hide = true)]
|
||||
GenerateTs(GenerateTsCommand),
|
||||
|
||||
/// Internal: run the responses API proxy.
|
||||
#[clap(hide = true)]
|
||||
ResponsesApiProxy(ResponsesApiProxyArgs),
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
@@ -194,6 +201,34 @@ fn print_exit_messages(exit_info: AppExitInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const CODEX_SECURE_MODE_ENV_VAR: &str = "CODEX_SECURE_MODE";
|
||||
|
||||
/// As early as possible in the process lifecycle, apply hardening measures
|
||||
/// if the CODEX_SECURE_MODE environment variable is set to "1".
|
||||
#[ctor::ctor]
|
||||
fn pre_main_hardening() {
|
||||
let secure_mode = match std::env::var(CODEX_SECURE_MODE_ENV_VAR) {
|
||||
Ok(value) => value,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
if secure_mode == "1" {
|
||||
#[cfg(any(target_os = "linux", target_os = "android"))]
|
||||
crate::pre_main_hardening::pre_main_hardening_linux();
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
crate::pre_main_hardening::pre_main_hardening_macos();
|
||||
|
||||
#[cfg(windows)]
|
||||
crate::pre_main_hardening::pre_main_hardening_windows();
|
||||
}
|
||||
|
||||
// Always clear this env var so child processes don't inherit it.
|
||||
unsafe {
|
||||
std::env::remove_var(CODEX_SECURE_MODE_ENV_VAR);
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
cli_main(codex_linux_sandbox_exe).await?;
|
||||
@@ -312,6 +347,11 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
Some(Subcommand::GenerateTs(gen_cli)) => {
|
||||
codex_protocol_ts::generate_ts(&gen_cli.out_dir, gen_cli.prettier.as_deref())?;
|
||||
}
|
||||
Some(Subcommand::ResponsesApiProxy(args)) => {
|
||||
tokio::task::spawn_blocking(move || codex_responses_api_proxy::run_main(args))
|
||||
.await
|
||||
.context("responses-api-proxy blocking task panicked")??;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
98
codex-rs/cli/src/pre_main_hardening.rs
Normal file
98
codex-rs/cli/src/pre_main_hardening.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
#[cfg(any(target_os = "linux", target_os = "android"))]
|
||||
const PRCTL_FAILED_EXIT_CODE: i32 = 5;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
const PTRACE_DENY_ATTACH_FAILED_EXIT_CODE: i32 = 6;
|
||||
|
||||
#[cfg(any(target_os = "linux", target_os = "android", target_os = "macos"))]
|
||||
const SET_RLIMIT_CORE_FAILED_EXIT_CODE: i32 = 7;
|
||||
|
||||
#[cfg(any(target_os = "linux", target_os = "android"))]
|
||||
pub(crate) fn pre_main_hardening_linux() {
|
||||
// Disable ptrace attach / mark process non-dumpable.
|
||||
let ret_code = unsafe { libc::prctl(libc::PR_SET_DUMPABLE, 0, 0, 0, 0) };
|
||||
if ret_code != 0 {
|
||||
eprintln!(
|
||||
"ERROR: prctl(PR_SET_DUMPABLE, 0) failed: {}",
|
||||
std::io::Error::last_os_error()
|
||||
);
|
||||
std::process::exit(PRCTL_FAILED_EXIT_CODE);
|
||||
}
|
||||
|
||||
// For "defense in depth," set the core file size limit to 0.
|
||||
set_core_file_size_limit_to_zero();
|
||||
|
||||
// Official Codex releases are MUSL-linked, which means that variables such
|
||||
// as LD_PRELOAD are ignored anyway, but just to be sure, clear them here.
|
||||
let ld_keys: Vec<String> = std::env::vars()
|
||||
.filter_map(|(key, _)| {
|
||||
if key.starts_with("LD_") {
|
||||
Some(key)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for key in ld_keys {
|
||||
unsafe {
|
||||
std::env::remove_var(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub(crate) fn pre_main_hardening_macos() {
|
||||
// Prevent debuggers from attaching to this process.
|
||||
let ret_code = unsafe { libc::ptrace(libc::PT_DENY_ATTACH, 0, std::ptr::null_mut(), 0) };
|
||||
if ret_code == -1 {
|
||||
eprintln!(
|
||||
"ERROR: ptrace(PT_DENY_ATTACH) failed: {}",
|
||||
std::io::Error::last_os_error()
|
||||
);
|
||||
std::process::exit(PTRACE_DENY_ATTACH_FAILED_EXIT_CODE);
|
||||
}
|
||||
|
||||
// Set the core file size limit to 0 to prevent core dumps.
|
||||
set_core_file_size_limit_to_zero();
|
||||
|
||||
// Remove all DYLD_ environment variables, which can be used to subvert
|
||||
// library loading.
|
||||
let dyld_keys: Vec<String> = std::env::vars()
|
||||
.filter_map(|(key, _)| {
|
||||
if key.starts_with("DYLD_") {
|
||||
Some(key)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for key in dyld_keys {
|
||||
unsafe {
|
||||
std::env::remove_var(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn set_core_file_size_limit_to_zero() {
|
||||
let rlim = libc::rlimit {
|
||||
rlim_cur: 0,
|
||||
rlim_max: 0,
|
||||
};
|
||||
|
||||
let ret_code = unsafe { libc::setrlimit(libc::RLIMIT_CORE, &rlim) };
|
||||
if ret_code != 0 {
|
||||
eprintln!(
|
||||
"ERROR: setrlimit(RLIMIT_CORE) failed: {}",
|
||||
std::io::Error::last_os_error()
|
||||
);
|
||||
std::process::exit(SET_RLIMIT_CORE_FAILED_EXIT_CODE);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
pub(crate) fn pre_main_hardening_windows() {
|
||||
// TODO(mbolin): Perform the appropriate configuration for Windows.
|
||||
}
|
||||
0
codex-rs/code
Normal file
0
codex-rs/code
Normal file
@@ -15,17 +15,20 @@ workspace = true
|
||||
anyhow = { workspace = true }
|
||||
askama = { workspace = true }
|
||||
async-channel = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
codex-apply-patch = { workspace = true }
|
||||
codex-file-search = { workspace = true }
|
||||
codex-mcp-client = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
env-flags = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
indexmap = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
mcp-types = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
@@ -80,6 +83,7 @@ openssl-sys = { workspace = true, features = ["vendored"] }
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
core_test_support = { workspace = true }
|
||||
escargot = { workspace = true }
|
||||
maplit = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::FileChange;
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::safety::assess_patch_safety;
|
||||
use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_apply_patch::ApplyPatchFileChange;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -17,7 +16,7 @@ pub(crate) enum InternalApplyPatchInvocation {
|
||||
/// The `apply_patch` call was handled programmatically, without any sort
|
||||
/// of sandbox, because the user explicitly approved it. This is the
|
||||
/// result to use with the `shell` function call that contained `apply_patch`.
|
||||
Output(ResponseInputItem),
|
||||
Output(Result<String, FunctionCallError>),
|
||||
|
||||
/// The `apply_patch` call was approved, either automatically because it
|
||||
/// appears that it should be allowed based on the user's sandbox policy
|
||||
@@ -33,12 +32,6 @@ pub(crate) struct ApplyPatchExec {
|
||||
pub(crate) user_explicitly_approved_this_action: bool,
|
||||
}
|
||||
|
||||
impl From<ResponseInputItem> for InternalApplyPatchInvocation {
|
||||
fn from(item: ResponseInputItem) -> Self {
|
||||
InternalApplyPatchInvocation::Output(item)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn apply_patch(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
@@ -77,25 +70,15 @@ pub(crate) async fn apply_patch(
|
||||
})
|
||||
}
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call_id.to_owned(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "patch rejected by user".to_string(),
|
||||
success: Some(false),
|
||||
},
|
||||
}
|
||||
.into()
|
||||
InternalApplyPatchInvocation::Output(Err(FunctionCallError::RespondToModel(
|
||||
"patch rejected by user".to_string(),
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
SafetyCheck::Reject { reason } => ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call_id.to_owned(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("patch rejected: {reason}"),
|
||||
success: Some(false),
|
||||
},
|
||||
}
|
||||
.into(),
|
||||
SafetyCheck::Reject { reason } => InternalApplyPatchInvocation::Output(Err(
|
||||
FunctionCallError::RespondToModel(format!("patch rejected: {reason}")),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -267,6 +267,9 @@ pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result<AuthDotJson> {
|
||||
}
|
||||
|
||||
pub fn write_auth_json(auth_file: &Path, auth_dot_json: &AuthDotJson) -> std::io::Result<()> {
|
||||
if let Some(parent) = auth_file.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
let json_data = serde_json::to_string_pretty(auth_dot_json)?;
|
||||
let mut options = OpenOptions::new();
|
||||
options.truncate(true).write(true).create(true);
|
||||
|
||||
@@ -88,6 +88,21 @@ pub fn try_parse_word_only_commands_sequence(tree: &Tree, src: &str) -> Option<V
|
||||
Some(commands)
|
||||
}
|
||||
|
||||
/// Returns the sequence of plain commands within a `bash -lc "..."` invocation
|
||||
/// when the script only contains word-only commands joined by safe operators.
|
||||
pub fn parse_bash_lc_plain_commands(command: &[String]) -> Option<Vec<Vec<String>>> {
|
||||
let [bash, flag, script] = command else {
|
||||
return None;
|
||||
};
|
||||
|
||||
if bash != "bash" || flag != "-lc" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let tree = try_parse_bash(script)?;
|
||||
try_parse_word_only_commands_sequence(&tree, script)
|
||||
}
|
||||
|
||||
fn parse_plain_command_from_node(cmd: tree_sitter::Node, src: &str) -> Option<Vec<String>> {
|
||||
if cmd.kind() != "command" {
|
||||
return None;
|
||||
|
||||
@@ -43,6 +43,7 @@ use crate::model_provider_info::WireApi;
|
||||
use crate::openai_model_info::get_model_info;
|
||||
use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::RateLimitWindow;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::token_data::PlanType;
|
||||
use crate::util::backoff;
|
||||
@@ -183,19 +184,23 @@ impl ModelClient {
|
||||
|
||||
let input_with_instructions = prompt.get_formatted_input();
|
||||
|
||||
// Only include `text.verbosity` for GPT-5 family models
|
||||
let text = if self.config.model_family.family == "gpt-5" {
|
||||
create_text_param_for_request(self.config.model_verbosity, &prompt.output_schema)
|
||||
} else {
|
||||
if self.config.model_verbosity.is_some() {
|
||||
warn!(
|
||||
"model_verbosity is set but ignored for non-gpt-5 model family: {}",
|
||||
self.config.model_family.family
|
||||
);
|
||||
let verbosity = match &self.config.model_family.family {
|
||||
family if family == "gpt-5" => self.config.model_verbosity,
|
||||
_ => {
|
||||
if self.config.model_verbosity.is_some() {
|
||||
warn!(
|
||||
"model_verbosity is set but ignored for non-gpt-5 model family: {}",
|
||||
self.config.model_family.family
|
||||
);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
// Only include `text.verbosity` for GPT-5 family models
|
||||
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
|
||||
|
||||
// In general, we want to explicitly send `store: false` when using the Responses API,
|
||||
// but in practice, the Azure Responses API rejects `store: false`:
|
||||
//
|
||||
@@ -224,155 +229,169 @@ impl ModelClient {
|
||||
if azure_workaround {
|
||||
attach_item_ids(&mut payload_json, &input_with_instructions);
|
||||
}
|
||||
let payload_body = serde_json::to_string(&payload_json)?;
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = self.provider.request_max_retries();
|
||||
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
|
||||
trace!(
|
||||
"POST to {}: {}",
|
||||
self.provider.get_full_url(&auth),
|
||||
payload_body.as_str()
|
||||
);
|
||||
|
||||
let mut req_builder = self
|
||||
.provider
|
||||
.create_request_builder(&self.client, &auth)
|
||||
.await?;
|
||||
|
||||
req_builder = req_builder
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
// Send session_id for compatibility.
|
||||
.header("conversation_id", self.conversation_id.to_string())
|
||||
.header("session_id", self.conversation_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload_json);
|
||||
|
||||
if let Some(auth) = auth.as_ref()
|
||||
&& auth.mode == AuthMode::ChatGPT
|
||||
&& let Some(account_id) = auth.get_account_id()
|
||||
let max_attempts = self.provider.request_max_retries();
|
||||
for attempt in 0..=max_attempts {
|
||||
match self
|
||||
.attempt_stream_responses(&payload_json, &auth_manager)
|
||||
.await
|
||||
{
|
||||
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, cf-ray: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("cf-ray")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
|
||||
if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::RateLimits(snapshot)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
debug!("receiver dropped rate limit snapshot event");
|
||||
}
|
||||
|
||||
// spawn task to process SSE
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
self.provider.stream_idle_timeout(),
|
||||
));
|
||||
|
||||
return Ok(ResponseStream { rx_event });
|
||||
Ok(stream) => {
|
||||
return Ok(stream);
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = res.status();
|
||||
|
||||
// Pull out Retry‑After header if present.
|
||||
let retry_after_secs = res
|
||||
.headers()
|
||||
.get(reqwest::header::RETRY_AFTER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
|
||||
if status == StatusCode::UNAUTHORIZED
|
||||
&& let Some(manager) = auth_manager.as_ref()
|
||||
&& manager.auth().is_some()
|
||||
{
|
||||
let _ = manager.refresh_token().await;
|
||||
Err(StreamAttemptError::Fatal(e)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Err(retryable_attempt_error) => {
|
||||
if attempt == max_attempts {
|
||||
return Err(retryable_attempt_error.into_error());
|
||||
}
|
||||
|
||||
// The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx
|
||||
// errors. When we bubble early with only the HTTP status the caller sees an opaque
|
||||
// "unexpected status 400 Bad Request" which makes debugging nearly impossible.
|
||||
// Instead, read (and include) the response text so higher layers and users see the
|
||||
// exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is
|
||||
// small and this branch only runs on error paths so the extra allocation is
|
||||
// negligible.
|
||||
if !(status == StatusCode::TOO_MANY_REQUESTS
|
||||
|| status == StatusCode::UNAUTHORIZED
|
||||
|| status.is_server_error())
|
||||
{
|
||||
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
||||
let body = res.text().await.unwrap_or_default();
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if status == StatusCode::TOO_MANY_REQUESTS {
|
||||
let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers());
|
||||
let body = res.json::<ErrorResponse>().await.ok();
|
||||
if let Some(ErrorResponse { error }) = body {
|
||||
if error.r#type.as_deref() == Some("usage_limit_reached") {
|
||||
// Prefer the plan_type provided in the error message if present
|
||||
// because it's more up to date than the one encoded in the auth
|
||||
// token.
|
||||
let plan_type = error
|
||||
.plan_type
|
||||
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
||||
let resets_in_seconds = error.resets_in_seconds;
|
||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type,
|
||||
resets_in_seconds,
|
||||
rate_limits: rate_limit_snapshot,
|
||||
}));
|
||||
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
||||
return Err(CodexErr::UsageNotIncluded);
|
||||
}
|
||||
unreachable!("stream_responses_attempt should always return");
|
||||
}
|
||||
|
||||
/// Single attempt to start a streaming Responses API call.
|
||||
async fn attempt_stream_responses(
|
||||
&self,
|
||||
payload_json: &Value,
|
||||
auth_manager: &Option<Arc<AuthManager>>,
|
||||
) -> std::result::Result<ResponseStream, StreamAttemptError> {
|
||||
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
|
||||
trace!(
|
||||
"POST to {}: {:?}",
|
||||
self.provider.get_full_url(&auth),
|
||||
serde_json::to_string(payload_json)
|
||||
);
|
||||
|
||||
let mut req_builder = self
|
||||
.provider
|
||||
.create_request_builder(&self.client, &auth)
|
||||
.await
|
||||
.map_err(StreamAttemptError::Fatal)?;
|
||||
|
||||
req_builder = req_builder
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
// Send session_id for compatibility.
|
||||
.header("conversation_id", self.conversation_id.to_string())
|
||||
.header("session_id", self.conversation_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(payload_json);
|
||||
|
||||
if let Some(auth) = auth.as_ref()
|
||||
&& auth.mode == AuthMode::ChatGPT
|
||||
&& let Some(account_id) = auth.get_account_id()
|
||||
{
|
||||
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
||||
}
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, cf-ray: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("cf-ray")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
|
||||
if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::RateLimits(snapshot)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
debug!("receiver dropped rate limit snapshot event");
|
||||
}
|
||||
|
||||
// spawn task to process SSE
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
self.provider.stream_idle_timeout(),
|
||||
));
|
||||
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = res.status();
|
||||
|
||||
// Pull out Retry‑After header if present.
|
||||
let retry_after_secs = res
|
||||
.headers()
|
||||
.get(reqwest::header::RETRY_AFTER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));
|
||||
|
||||
if status == StatusCode::UNAUTHORIZED
|
||||
&& let Some(manager) = auth_manager.as_ref()
|
||||
&& manager.auth().is_some()
|
||||
{
|
||||
let _ = manager.refresh_token().await;
|
||||
}
|
||||
|
||||
// The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx
|
||||
// errors. When we bubble early with only the HTTP status the caller sees an opaque
|
||||
// "unexpected status 400 Bad Request" which makes debugging nearly impossible.
|
||||
// Instead, read (and include) the response text so higher layers and users see the
|
||||
// exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is
|
||||
// small and this branch only runs on error paths so the extra allocation is
|
||||
// negligible.
|
||||
if !(status == StatusCode::TOO_MANY_REQUESTS
|
||||
|| status == StatusCode::UNAUTHORIZED
|
||||
|| status.is_server_error())
|
||||
{
|
||||
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
||||
let body = res.text().await.unwrap_or_default();
|
||||
return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus(
|
||||
status, body,
|
||||
)));
|
||||
}
|
||||
|
||||
if status == StatusCode::TOO_MANY_REQUESTS {
|
||||
let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers());
|
||||
let body = res.json::<ErrorResponse>().await.ok();
|
||||
if let Some(ErrorResponse { error }) = body {
|
||||
if error.r#type.as_deref() == Some("usage_limit_reached") {
|
||||
// Prefer the plan_type provided in the error message if present
|
||||
// because it's more up to date than the one encoded in the auth
|
||||
// token.
|
||||
let plan_type = error
|
||||
.plan_type
|
||||
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
||||
let resets_in_seconds = error.resets_in_seconds;
|
||||
let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type,
|
||||
resets_in_seconds,
|
||||
rate_limits: rate_limit_snapshot,
|
||||
});
|
||||
return Err(StreamAttemptError::Fatal(codex_err));
|
||||
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
||||
return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
|
||||
}
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||
return Err(CodexErr::InternalServerError);
|
||||
}
|
||||
|
||||
return Err(CodexErr::RetryLimit(status));
|
||||
}
|
||||
|
||||
let delay = retry_after_secs
|
||||
.map(|s| Duration::from_millis(s * 1_000))
|
||||
.unwrap_or_else(|| backoff(attempt));
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
return Err(e.into());
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
|
||||
Err(StreamAttemptError::RetryableHttpError {
|
||||
status,
|
||||
retry_after,
|
||||
})
|
||||
}
|
||||
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,6 +424,47 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
enum StreamAttemptError {
|
||||
RetryableHttpError {
|
||||
status: StatusCode,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
RetryableTransportError(CodexErr),
|
||||
Fatal(CodexErr),
|
||||
}
|
||||
|
||||
impl StreamAttemptError {
|
||||
/// attempt is 0-based.
|
||||
fn delay(&self, attempt: u64) -> Duration {
|
||||
// backoff() uses 1-based attempts.
|
||||
let backoff_attempt = attempt + 1;
|
||||
match self {
|
||||
Self::RetryableHttpError { retry_after, .. } => {
|
||||
retry_after.unwrap_or_else(|| backoff(backoff_attempt))
|
||||
}
|
||||
Self::RetryableTransportError { .. } => backoff(backoff_attempt),
|
||||
Self::Fatal(_) => {
|
||||
// Should not be called on Fatal errors.
|
||||
Duration::from_secs(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn into_error(self) -> CodexErr {
|
||||
match self {
|
||||
Self::RetryableHttpError { status, .. } => {
|
||||
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||
CodexErr::InternalServerError
|
||||
} else {
|
||||
CodexErr::RetryLimit(status)
|
||||
}
|
||||
}
|
||||
Self::RetryableTransportError(error) => error,
|
||||
Self::Fatal(error) => error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct SseEvent {
|
||||
#[serde(rename = "type")]
|
||||
@@ -414,9 +474,6 @@ struct SseEvent {
|
||||
delta: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCreated {}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompleted {
|
||||
id: String,
|
||||
@@ -488,19 +545,44 @@ fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
|
||||
}
|
||||
|
||||
fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
let primary_used_percent = parse_header_f64(headers, "x-codex-primary-used-percent")?;
|
||||
let secondary_used_percent = parse_header_f64(headers, "x-codex-secondary-used-percent")?;
|
||||
let primary_to_secondary_ratio_percent =
|
||||
parse_header_f64(headers, "x-codex-primary-over-secondary-limit-percent")?;
|
||||
let primary_window_minutes = parse_header_u64(headers, "x-codex-primary-window-minutes")?;
|
||||
let secondary_window_minutes = parse_header_u64(headers, "x-codex-secondary-window-minutes")?;
|
||||
let primary = parse_rate_limit_window(
|
||||
headers,
|
||||
"x-codex-primary-used-percent",
|
||||
"x-codex-primary-window-minutes",
|
||||
"x-codex-primary-reset-after-seconds",
|
||||
);
|
||||
|
||||
Some(RateLimitSnapshot {
|
||||
primary_used_percent,
|
||||
secondary_used_percent,
|
||||
primary_to_secondary_ratio_percent,
|
||||
primary_window_minutes,
|
||||
secondary_window_minutes,
|
||||
let secondary = parse_rate_limit_window(
|
||||
headers,
|
||||
"x-codex-secondary-used-percent",
|
||||
"x-codex-secondary-window-minutes",
|
||||
"x-codex-secondary-reset-after-seconds",
|
||||
);
|
||||
|
||||
Some(RateLimitSnapshot { primary, secondary })
|
||||
}
|
||||
|
||||
fn parse_rate_limit_window(
|
||||
headers: &HeaderMap,
|
||||
used_percent_header: &str,
|
||||
window_minutes_header: &str,
|
||||
resets_header: &str,
|
||||
) -> Option<RateLimitWindow> {
|
||||
let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
|
||||
|
||||
used_percent.and_then(|used_percent| {
|
||||
let window_minutes = parse_header_u64(headers, window_minutes_header);
|
||||
let resets_in_seconds = parse_header_u64(headers, resets_header);
|
||||
|
||||
let has_data = used_percent != 0.0
|
||||
|| window_minutes.is_some_and(|minutes| minutes != 0)
|
||||
|| resets_in_seconds.is_some_and(|seconds| seconds != 0);
|
||||
|
||||
has_data.then_some(RateLimitWindow {
|
||||
used_percent,
|
||||
window_minutes,
|
||||
resets_in_seconds,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::AgentTask;
|
||||
use super::Session;
|
||||
use super::TurnContext;
|
||||
use super::get_last_assistant_message_from_turn;
|
||||
@@ -15,7 +14,6 @@ use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
use crate::protocol::TurnContextItem;
|
||||
use crate::truncate::truncate_middle;
|
||||
@@ -37,17 +35,7 @@ struct HistoryBridgeTemplate<'a> {
|
||||
summary_text: &'a str,
|
||||
}
|
||||
|
||||
pub(super) async fn spawn_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
let task = AgentTask::compact(sess.clone(), turn_context, sub_id, input);
|
||||
sess.set_task(task).await;
|
||||
}
|
||||
|
||||
pub(super) async fn run_inline_auto_compact_task(
|
||||
pub(crate) async fn run_inline_auto_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
) {
|
||||
@@ -55,15 +43,15 @@ pub(super) async fn run_inline_auto_compact_task(
|
||||
let input = vec![InputItem::Text {
|
||||
text: SUMMARIZATION_PROMPT.to_string(),
|
||||
}];
|
||||
run_compact_task_inner(sess, turn_context, sub_id, input, false).await;
|
||||
run_compact_task_inner(sess, turn_context, sub_id, input).await;
|
||||
}
|
||||
|
||||
pub(super) async fn run_compact_task(
|
||||
pub(crate) async fn run_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
) -> Option<String> {
|
||||
let start_event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
||||
@@ -71,14 +59,8 @@ pub(super) async fn run_compact_task(
|
||||
}),
|
||||
};
|
||||
sess.send_event(start_event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input, true).await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await;
|
||||
None
|
||||
}
|
||||
|
||||
async fn run_compact_task_inner(
|
||||
@@ -86,7 +68,6 @@ async fn run_compact_task_inner(
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
remove_task_on_completion: bool,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input = sess
|
||||
@@ -112,7 +93,8 @@ async fn run_compact_task_inner(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
loop {
|
||||
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
|
||||
let attempt_result =
|
||||
drain_to_completed(&sess, turn_context.as_ref(), &sub_id, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => {
|
||||
@@ -148,21 +130,12 @@ async fn run_compact_task_inner(
|
||||
}
|
||||
}
|
||||
|
||||
if remove_task_on_completion {
|
||||
sess.remove_task(&sub_id).await;
|
||||
}
|
||||
let history_snapshot = {
|
||||
let state = sess.state.lock().await;
|
||||
state.history.contents()
|
||||
};
|
||||
let history_snapshot = sess.history_snapshot().await;
|
||||
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
|
||||
let user_messages = collect_user_messages(&history_snapshot);
|
||||
let initial_context = sess.build_initial_context(turn_context.as_ref());
|
||||
let new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.history.replace(new_history);
|
||||
}
|
||||
sess.replace_history(new_history).await;
|
||||
|
||||
let rollout_item = RolloutItem::Compacted(CompactedItem {
|
||||
message: summary_text.clone(),
|
||||
@@ -257,6 +230,7 @@ pub(crate) fn build_compacted_history(
|
||||
async fn drain_to_completed(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
@@ -270,10 +244,14 @@ async fn drain_to_completed(
|
||||
};
|
||||
match event {
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
let mut state = sess.state.lock().await;
|
||||
state.history.record_items(std::slice::from_ref(&item));
|
||||
sess.record_into_history(std::slice::from_ref(&item)).await;
|
||||
}
|
||||
Ok(ResponseEvent::Completed { .. }) => {
|
||||
Ok(ResponseEvent::RateLimits(snapshot)) => {
|
||||
sess.update_rate_limits(sub_id, snapshot).await;
|
||||
}
|
||||
Ok(ResponseEvent::Completed { token_usage, .. }) => {
|
||||
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
Ok(_) => continue,
|
||||
|
||||
99
codex-rs/core/src/command_safety/is_dangerous_command.rs
Normal file
99
codex-rs/core/src/command_safety/is_dangerous_command.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use crate::bash::parse_bash_lc_plain_commands;
|
||||
|
||||
pub fn command_might_be_dangerous(command: &[String]) -> bool {
|
||||
if is_dangerous_to_call_with_exec(command) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Support `bash -lc "<script>"` where the any part of the script might contain a dangerous command.
|
||||
if let Some(all_commands) = parse_bash_lc_plain_commands(command)
|
||||
&& all_commands
|
||||
.iter()
|
||||
.any(|cmd| is_dangerous_to_call_with_exec(cmd))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn is_dangerous_to_call_with_exec(command: &[String]) -> bool {
|
||||
let cmd0 = command.first().map(String::as_str);
|
||||
|
||||
match cmd0 {
|
||||
Some(cmd) if cmd.ends_with("git") || cmd.ends_with("/git") => {
|
||||
matches!(command.get(1).map(String::as_str), Some("reset" | "rm"))
|
||||
}
|
||||
|
||||
Some("rm") => matches!(command.get(1).map(String::as_str), Some("-f" | "-rf")),
|
||||
|
||||
// for sudo <cmd> simply do the check for <cmd>
|
||||
Some("sudo") => is_dangerous_to_call_with_exec(&command[1..]),
|
||||
|
||||
// ── anything else ─────────────────────────────────────────────────
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn vec_str(items: &[&str]) -> Vec<String> {
|
||||
items.iter().map(std::string::ToString::to_string).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn git_reset_is_dangerous() {
|
||||
assert!(command_might_be_dangerous(&vec_str(&["git", "reset"])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_git_reset_is_dangerous() {
|
||||
assert!(command_might_be_dangerous(&vec_str(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
"git reset --hard"
|
||||
])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn git_status_is_not_dangerous() {
|
||||
assert!(!command_might_be_dangerous(&vec_str(&["git", "status"])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_git_status_is_not_dangerous() {
|
||||
assert!(!command_might_be_dangerous(&vec_str(&[
|
||||
"bash",
|
||||
"-lc",
|
||||
"git status"
|
||||
])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sudo_git_reset_is_dangerous() {
|
||||
assert!(command_might_be_dangerous(&vec_str(&[
|
||||
"sudo", "git", "reset", "--hard"
|
||||
])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn usr_bin_git_is_dangerous() {
|
||||
assert!(command_might_be_dangerous(&vec_str(&[
|
||||
"/usr/bin/git",
|
||||
"reset",
|
||||
"--hard"
|
||||
])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rm_rf_is_dangerous() {
|
||||
assert!(command_might_be_dangerous(&vec_str(&["rm", "-rf", "/"])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rm_f_is_dangerous() {
|
||||
assert!(command_might_be_dangerous(&vec_str(&["rm", "-f", "/"])));
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,14 @@
|
||||
use crate::bash::try_parse_bash;
|
||||
use crate::bash::try_parse_word_only_commands_sequence;
|
||||
use crate::bash::parse_bash_lc_plain_commands;
|
||||
|
||||
pub fn is_known_safe_command(command: &[String]) -> bool {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use super::windows_safe_commands::is_safe_command_windows;
|
||||
if is_safe_command_windows(command) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if is_safe_to_call_with_exec(command) {
|
||||
return true;
|
||||
}
|
||||
@@ -12,11 +19,7 @@ pub fn is_known_safe_command(command: &[String]) -> bool {
|
||||
// introduce side effects ( "&&", "||", ";", and "|" ). If every
|
||||
// individual command in the script is itself a known‑safe command, then
|
||||
// the composite expression is considered safe.
|
||||
if let [bash, flag, script] = command
|
||||
&& bash == "bash"
|
||||
&& flag == "-lc"
|
||||
&& let Some(tree) = try_parse_bash(script)
|
||||
&& let Some(all_commands) = try_parse_word_only_commands_sequence(&tree, script)
|
||||
if let Some(all_commands) = parse_bash_lc_plain_commands(command)
|
||||
&& !all_commands.is_empty()
|
||||
&& all_commands
|
||||
.iter()
|
||||
@@ -24,7 +27,6 @@ pub fn is_known_safe_command(command: &[String]) -> bool {
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
4
codex-rs/core/src/command_safety/mod.rs
Normal file
4
codex-rs/core/src/command_safety/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod is_dangerous_command;
|
||||
pub mod is_safe_command;
|
||||
#[cfg(target_os = "windows")]
|
||||
pub mod windows_safe_commands;
|
||||
25
codex-rs/core/src/command_safety/windows_safe_commands.rs
Normal file
25
codex-rs/core/src/command_safety/windows_safe_commands.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
// This is a WIP. This will eventually contain a real list of common safe Windows commands.
|
||||
pub fn is_safe_command_windows(_command: &[String]) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::is_safe_command_windows;
|
||||
|
||||
fn vec_str(args: &[&str]) -> Vec<String> {
|
||||
args.iter().map(ToString::to_string).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn everything_is_unsafe() {
|
||||
for cmd in [
|
||||
vec_str(&["powershell.exe", "-NoLogo", "-Command", "echo hello"]),
|
||||
vec_str(&["copy", "foo", "bar"]),
|
||||
vec_str(&["del", "file.txt"]),
|
||||
vec_str(&["powershell.exe", "Get-ChildItem"]),
|
||||
] {
|
||||
assert!(!is_safe_command_windows(&cmd));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,6 +184,10 @@ pub struct Config {
|
||||
/// If set to `true`, used only the experimental unified exec tool.
|
||||
pub use_experimental_unified_exec_tool: bool,
|
||||
|
||||
/// If set to `true`, use the experimental official Rust MCP client.
|
||||
/// https://github.com/modelcontextprotocol/rust-sdk
|
||||
pub use_experimental_use_rmcp_client: bool,
|
||||
|
||||
/// Include the `view_image` tool that lets the agent attach a local image path to context.
|
||||
pub include_view_image_tool: bool,
|
||||
|
||||
@@ -693,6 +697,7 @@ pub struct ConfigToml {
|
||||
|
||||
pub experimental_use_exec_command_tool: Option<bool>,
|
||||
pub experimental_use_unified_exec_tool: Option<bool>,
|
||||
pub experimental_use_rmcp_client: Option<bool>,
|
||||
|
||||
pub projects: Option<HashMap<String, ProjectConfig>>,
|
||||
|
||||
@@ -1043,6 +1048,7 @@ impl Config {
|
||||
use_experimental_unified_exec_tool: cfg
|
||||
.experimental_use_unified_exec_tool
|
||||
.unwrap_or(false),
|
||||
use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false),
|
||||
include_view_image_tool,
|
||||
active_profile: active_profile_name,
|
||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||
@@ -1651,6 +1657,7 @@ model_verbosity = "high"
|
||||
tools_web_search_request: false,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("o3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
@@ -1709,6 +1716,7 @@ model_verbosity = "high"
|
||||
tools_web_search_request: false,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt3".to_string()),
|
||||
disable_paste_burst: false,
|
||||
@@ -1782,6 +1790,7 @@ model_verbosity = "high"
|
||||
tools_web_search_request: false,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("zdr".to_string()),
|
||||
disable_paste_burst: false,
|
||||
@@ -1841,6 +1850,7 @@ model_verbosity = "high"
|
||||
tools_web_search_request: false,
|
||||
use_experimental_streamable_shell_tool: false,
|
||||
use_experimental_unified_exec_tool: false,
|
||||
use_experimental_use_rmcp_client: false,
|
||||
include_view_image_tool: true,
|
||||
active_profile: Some("gpt5".to_string()),
|
||||
disable_paste_burst: false,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use reqwest::header::HeaderValue;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::Mutex;
|
||||
@@ -20,7 +21,6 @@ use std::sync::Mutex;
|
||||
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
|
||||
|
||||
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Originator {
|
||||
pub value: String,
|
||||
@@ -112,17 +112,25 @@ pub fn create_client() -> reqwest::Client {
|
||||
headers.insert("originator", ORIGINATOR.header_value.clone());
|
||||
let ua = get_codex_user_agent();
|
||||
|
||||
reqwest::Client::builder()
|
||||
let mut builder = reqwest::Client::builder()
|
||||
// Set UA via dedicated helper to avoid header validation pitfalls
|
||||
.user_agent(ua)
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
.default_headers(headers);
|
||||
if is_sandboxed() {
|
||||
builder = builder.no_proxy();
|
||||
}
|
||||
|
||||
builder.build().unwrap_or_else(|_| reqwest::Client::new())
|
||||
}
|
||||
|
||||
fn is_sandboxed() -> bool {
|
||||
std::env::var(CODEX_SANDBOX_ENV_VAR).as_deref() == Ok("seatbelt")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use core_test_support::skip_if_no_network;
|
||||
|
||||
#[test]
|
||||
fn test_get_codex_user_agent() {
|
||||
@@ -132,6 +140,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_client_sets_default_headers() {
|
||||
skip_if_no_network!();
|
||||
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
|
||||
@@ -156,7 +156,7 @@ impl std::fmt::Display for UsageLimitReachedError {
|
||||
)
|
||||
}
|
||||
Some(PlanType::Known(KnownPlan::Free)) => {
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
"You've hit your usage limit. Upgrade to Plus to continue using Codex (https://openai.com/chatgpt/pricing)."
|
||||
.to_string()
|
||||
}
|
||||
Some(PlanType::Known(KnownPlan::Pro))
|
||||
@@ -267,14 +267,20 @@ pub fn get_error_message_ui(e: &CodexErr) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
|
||||
fn rate_limit_snapshot() -> RateLimitSnapshot {
|
||||
RateLimitSnapshot {
|
||||
primary_used_percent: 0.5,
|
||||
secondary_used_percent: 0.3,
|
||||
primary_to_secondary_ratio_percent: 0.7,
|
||||
primary_window_minutes: 60,
|
||||
secondary_window_minutes: 120,
|
||||
primary: Some(RateLimitWindow {
|
||||
used_percent: 50.0,
|
||||
window_minutes: Some(60),
|
||||
resets_in_seconds: Some(3600),
|
||||
}),
|
||||
secondary: Some(RateLimitWindow {
|
||||
used_percent: 30.0,
|
||||
window_minutes: Some(120),
|
||||
resets_in_seconds: Some(7200),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,7 +306,7 @@ mod tests {
|
||||
};
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
|
||||
"You've hit your usage limit. Upgrade to Plus to continue using Codex (https://openai.com/chatgpt/pricing)."
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -12,4 +12,3 @@ pub use responses_api::WRITE_STDIN_TOOL_NAME;
|
||||
pub use responses_api::create_exec_command_tool_for_responses_api;
|
||||
pub use responses_api::create_write_stdin_tool_for_responses_api;
|
||||
pub use session_manager::SessionManager as ExecSessionManager;
|
||||
pub use session_manager::result_into_payload;
|
||||
|
||||
@@ -21,7 +21,6 @@ use crate::exec_command::exec_command_params::WriteStdinParams;
|
||||
use crate::exec_command::exec_command_session::ExecCommandSession;
|
||||
use crate::exec_command::session_id::SessionId;
|
||||
use crate::truncate::truncate_middle;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SessionManager {
|
||||
@@ -38,7 +37,7 @@ pub struct ExecCommandOutput {
|
||||
}
|
||||
|
||||
impl ExecCommandOutput {
|
||||
fn to_text_output(&self) -> String {
|
||||
pub(crate) fn to_text_output(&self) -> String {
|
||||
let wall_time_secs = self.wall_time.as_secs_f32();
|
||||
let termination_status = match self.exit_status {
|
||||
ExitStatus::Exited(code) => format!("Process exited with code {code}"),
|
||||
@@ -68,19 +67,6 @@ pub enum ExitStatus {
|
||||
Ongoing(SessionId),
|
||||
}
|
||||
|
||||
pub fn result_into_payload(result: Result<ExecCommandOutput, String>) -> FunctionCallOutputPayload {
|
||||
match result {
|
||||
Ok(output) => FunctionCallOutputPayload {
|
||||
content: output.to_text_output(),
|
||||
success: Some(true),
|
||||
},
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: err,
|
||||
success: Some(false),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
/// Processes the request and is required to send a response via `outgoing`.
|
||||
pub async fn handle_exec_command_request(
|
||||
|
||||
@@ -1,16 +1,6 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use env_flags::env_flags;
|
||||
|
||||
env_flags! {
|
||||
pub OPENAI_API_BASE: &str = "https://api.openai.com/v1";
|
||||
|
||||
/// Fallback when the provider-specific key is not set.
|
||||
pub OPENAI_API_KEY: Option<&str> = None;
|
||||
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||
value.parse().map(Duration::from_millis)
|
||||
};
|
||||
|
||||
/// Fixture path for offline tests (see client.rs).
|
||||
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
||||
}
|
||||
|
||||
7
codex-rs/core/src/function_tool.rs
Normal file
7
codex-rs/core/src/function_tool.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error, PartialEq)]
|
||||
pub enum FunctionCallError {
|
||||
#[error("{0}")]
|
||||
RespondToModel(String),
|
||||
}
|
||||
@@ -589,6 +589,7 @@ pub async fn current_branch_name(cwd: &Path) -> Option<String> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use core_test_support::skip_if_sandbox;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
@@ -660,6 +661,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_recent_commits_orders_and_limits() {
|
||||
skip_if_sandbox!();
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ pub mod codex;
|
||||
mod codex_conversation;
|
||||
pub mod token_data;
|
||||
pub use codex_conversation::CodexConversation;
|
||||
mod command_safety;
|
||||
pub mod config;
|
||||
pub mod config_edit;
|
||||
pub mod config_profile;
|
||||
@@ -29,7 +30,6 @@ pub mod exec_env;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod internal_storage;
|
||||
mod is_safe_command;
|
||||
pub mod landlock;
|
||||
mod mcp_connection_manager;
|
||||
mod mcp_tool_call;
|
||||
@@ -75,10 +75,14 @@ pub use rollout::find_conversation_path_by_id_str;
|
||||
pub use rollout::list::ConversationItem;
|
||||
pub use rollout::list::ConversationsPage;
|
||||
pub use rollout::list::Cursor;
|
||||
mod function_tool;
|
||||
mod state;
|
||||
mod tasks;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use command_safety::is_safe_command;
|
||||
pub use safety::get_platform_sandbox;
|
||||
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
|
||||
// `codex_core::protocol::...` references continue to work across the workspace.
|
||||
|
||||
@@ -16,6 +16,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_mcp_client::McpClient;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::Tool;
|
||||
@@ -86,11 +87,64 @@ struct ToolInfo {
|
||||
}
|
||||
|
||||
struct ManagedClient {
|
||||
client: Arc<McpClient>,
|
||||
client: McpClientAdapter,
|
||||
startup_timeout: Duration,
|
||||
tool_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum McpClientAdapter {
|
||||
Legacy(Arc<McpClient>),
|
||||
Rmcp(Arc<RmcpClient>),
|
||||
}
|
||||
|
||||
impl McpClientAdapter {
|
||||
async fn new_stdio_client(
|
||||
use_rmcp_client: bool,
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
) -> Result<Self> {
|
||||
tracing::error!(
|
||||
"new_stdio_client use_rmcp_client: {use_rmcp_client} program: {program:?} args: {args:?} env: {env:?} params: {params:?} startup_timeout: {startup_timeout:?}"
|
||||
);
|
||||
if use_rmcp_client {
|
||||
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
} else {
|
||||
let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Legacy(client))
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_tools(
|
||||
&self,
|
||||
params: Option<mcp_types::ListToolsRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ListToolsResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(client) => client.list_tools(params, timeout).await,
|
||||
McpClientAdapter::Rmcp(client) => client.list_tools(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::CallToolResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(client) => client.call_tool(name, arguments, timeout).await,
|
||||
McpClientAdapter::Rmcp(client) => client.call_tool(name, arguments, timeout).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`McpClient`] instances.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct McpConnectionManager {
|
||||
@@ -115,12 +169,15 @@ impl McpConnectionManager {
|
||||
/// user should be informed about these errors.
|
||||
pub async fn new(
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
use_rmcp_client: bool,
|
||||
) -> Result<(Self, ClientStartErrors)> {
|
||||
// Early exit if no servers are configured.
|
||||
if mcp_servers.is_empty() {
|
||||
return Ok((Self::default(), ClientStartErrors::default()));
|
||||
}
|
||||
|
||||
tracing::error!("new mcp_servers: {mcp_servers:?} use_rmcp_client: {use_rmcp_client}");
|
||||
|
||||
// Launch all configured servers concurrently.
|
||||
let mut join_set = JoinSet::new();
|
||||
let mut errors = ClientStartErrors::new();
|
||||
@@ -137,57 +194,48 @@ impl McpConnectionManager {
|
||||
}
|
||||
|
||||
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
|
||||
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
|
||||
|
||||
let use_rmcp_client_flag = use_rmcp_client;
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig {
|
||||
command, args, env, ..
|
||||
} = cfg;
|
||||
let client_res = McpClient::new_stdio_client(
|
||||
command.into(),
|
||||
args.into_iter().map(OsString::from).collect(),
|
||||
let command_os: OsString = command.into();
|
||||
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
|
||||
let params = mcp_types::InitializeRequestParams {
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
|
||||
// indicates this should be an empty object.
|
||||
elicitation: Some(json!({})),
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
// This field is used by Codex when it is an MCP
|
||||
// server: it should not be used when Codex is
|
||||
// an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
let client = McpClientAdapter::new_stdio_client(
|
||||
use_rmcp_client_flag,
|
||||
command_os,
|
||||
args_os,
|
||||
env,
|
||||
params,
|
||||
startup_timeout,
|
||||
)
|
||||
.await;
|
||||
match client_res {
|
||||
Ok(client) => {
|
||||
// Initialize the client.
|
||||
let params = mcp_types::InitializeRequestParams {
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
|
||||
// indicates this should be an empty object.
|
||||
elicitation: Some(json!({})),
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
// This field is used by Codex when it is an MCP
|
||||
// server: it should not be used when Codex is
|
||||
// an MCP client.
|
||||
user_agent: None,
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
let initialize_notification_params = None;
|
||||
let init_result = client
|
||||
.initialize(
|
||||
params,
|
||||
initialize_notification_params,
|
||||
Some(startup_timeout),
|
||||
)
|
||||
.await;
|
||||
(
|
||||
(server_name, tool_timeout),
|
||||
init_result.map(|_| (client, startup_timeout)),
|
||||
)
|
||||
}
|
||||
Err(e) => ((server_name, tool_timeout), Err(e.into())),
|
||||
}
|
||||
.await
|
||||
.map(|c| (c, startup_timeout));
|
||||
|
||||
((server_name, tool_timeout), client)
|
||||
});
|
||||
}
|
||||
|
||||
@@ -207,7 +255,7 @@ impl McpConnectionManager {
|
||||
clients.insert(
|
||||
server_name,
|
||||
ManagedClient {
|
||||
client: Arc::new(client),
|
||||
client,
|
||||
startup_timeout,
|
||||
tool_timeout: Some(tool_timeout),
|
||||
},
|
||||
|
||||
@@ -7,13 +7,14 @@ use crate::model_family::ModelFamily;
|
||||
/// Though this would help present more accurate pricing information in the UI.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ModelInfo {
|
||||
/// Size of the context window in tokens.
|
||||
/// Size of the context window in tokens. This is the maximum size of the input context.
|
||||
pub(crate) context_window: u64,
|
||||
|
||||
/// Maximum number of output tokens that can be generated for the model.
|
||||
pub(crate) max_output_tokens: u64,
|
||||
|
||||
/// Token threshold where we should automatically compact conversation history.
|
||||
/// Token threshold where we should automatically compact conversation history. This considers
|
||||
/// input tokens + output tokens of this turn.
|
||||
pub(crate) auto_compact_token_limit: Option<i64>,
|
||||
}
|
||||
|
||||
@@ -64,7 +65,7 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
|
||||
_ if slug.starts_with("gpt-5-codex") => Some(ModelInfo {
|
||||
context_window: 272_000,
|
||||
max_output_tokens: 128_000,
|
||||
auto_compact_token_limit: Some(250_000),
|
||||
auto_compact_token_limit: Some(350_000),
|
||||
}),
|
||||
|
||||
_ if slug.starts_with("gpt-5") => Some(ModelInfo::new(272_000, 128_000)),
|
||||
|
||||
@@ -2,13 +2,12 @@ use std::collections::BTreeMap;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::openai_tools::JsonSchema;
|
||||
use crate::openai_tools::OpenAiTool;
|
||||
use crate::openai_tools::ResponsesApiTool;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
// Use the canonical plan tool types from the protocol crate to ensure
|
||||
// type-identity matches events transported via `codex_protocol`.
|
||||
@@ -67,44 +66,20 @@ pub(crate) async fn handle_update_plan(
|
||||
session: &Session,
|
||||
arguments: String,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> ResponseInputItem {
|
||||
match parse_update_plan_arguments(arguments, &call_id) {
|
||||
Ok(args) => {
|
||||
let output = ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "Plan updated".to_string(),
|
||||
success: Some(true),
|
||||
},
|
||||
};
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::PlanUpdate(args),
|
||||
})
|
||||
.await;
|
||||
output
|
||||
}
|
||||
Err(output) => *output,
|
||||
}
|
||||
_call_id: String,
|
||||
) -> Result<String, FunctionCallError> {
|
||||
let args = parse_update_plan_arguments(&arguments)?;
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::PlanUpdate(args),
|
||||
})
|
||||
.await;
|
||||
Ok("Plan updated".to_string())
|
||||
}
|
||||
|
||||
fn parse_update_plan_arguments(
|
||||
arguments: String,
|
||||
call_id: &str,
|
||||
) -> Result<UpdatePlanArgs, Box<ResponseInputItem>> {
|
||||
match serde_json::from_str::<UpdatePlanArgs>(&arguments) {
|
||||
Ok(args) => Ok(args),
|
||||
Err(e) => {
|
||||
let output = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call_id.to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("failed to parse function arguments: {e}"),
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
Err(Box::new(output))
|
||||
}
|
||||
}
|
||||
fn parse_update_plan_arguments(arguments: &str) -> Result<UpdatePlanArgs, FunctionCallError> {
|
||||
serde_json::from_str::<UpdatePlanArgs>(arguments).map_err(|e| {
|
||||
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
@@ -28,7 +26,6 @@ use super::policy::is_persisted_response_item;
|
||||
use crate::config::Config;
|
||||
use crate::default_client::ORIGINATOR;
|
||||
use crate::git_info::collect_git_info;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::ResumedHistory;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
@@ -36,19 +33,6 @@ use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: ConversationId,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
/// every update.
|
||||
///
|
||||
|
||||
@@ -7,7 +7,9 @@ use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_apply_patch::ApplyPatchFileChange;
|
||||
|
||||
use crate::exec::SandboxType;
|
||||
use crate::is_safe_command::is_known_safe_command;
|
||||
|
||||
use crate::command_safety::is_dangerous_command::command_might_be_dangerous;
|
||||
use crate::command_safety::is_safe_command::is_known_safe_command;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
@@ -85,6 +87,20 @@ pub fn assess_command_safety(
|
||||
approved: &HashSet<Vec<String>>,
|
||||
with_escalated_permissions: bool,
|
||||
) -> SafetyCheck {
|
||||
// Some commands look dangerous. Even if they are run inside a sandbox,
|
||||
// unless the user has explicitly approved them, we should ask,
|
||||
// or reject if the approval_policy tells us not to ask.
|
||||
if command_might_be_dangerous(command) && !approved.contains(command) {
|
||||
if approval_policy == AskForApproval::Never {
|
||||
return SafetyCheck::Reject {
|
||||
reason: "dangerous command detected; rejected by user approval settings"
|
||||
.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
return SafetyCheck::AskUser;
|
||||
}
|
||||
|
||||
// A command is "trusted" because either:
|
||||
// - it belongs to a set of commands we consider "safe" by default, or
|
||||
// - the user has explicitly approved the command for this session
|
||||
@@ -98,6 +114,7 @@ pub fn assess_command_safety(
|
||||
// would probably be fine to run the command in a sandbox, but when
|
||||
// `approved.contains(command)` is `true`, the user may have approved it for
|
||||
// the session _because_ they know it needs to run outside a sandbox.
|
||||
|
||||
if is_known_safe_command(command) || approved.contains(command) {
|
||||
return SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None,
|
||||
@@ -325,6 +342,56 @@ mod tests {
|
||||
assert_eq!(safety_check, SafetyCheck::AskUser);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dangerous_command_allowed_if_explicitly_approved() {
|
||||
let command = vec!["git".to_string(), "reset".to_string(), "--hard".to_string()];
|
||||
let approval_policy = AskForApproval::OnRequest;
|
||||
let sandbox_policy = SandboxPolicy::ReadOnly;
|
||||
let mut approved: HashSet<Vec<String>> = HashSet::new();
|
||||
approved.insert(command.clone());
|
||||
let request_escalated_privileges = false;
|
||||
|
||||
let safety_check = assess_command_safety(
|
||||
&command,
|
||||
approval_policy,
|
||||
&sandbox_policy,
|
||||
&approved,
|
||||
request_escalated_privileges,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
safety_check,
|
||||
SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dangerous_command_not_allowed_if_not_explicitly_approved() {
|
||||
let command = vec!["git".to_string(), "reset".to_string(), "--hard".to_string()];
|
||||
let approval_policy = AskForApproval::Never;
|
||||
let sandbox_policy = SandboxPolicy::ReadOnly;
|
||||
let approved: HashSet<Vec<String>> = HashSet::new();
|
||||
let request_escalated_privileges = false;
|
||||
|
||||
let safety_check = assess_command_safety(
|
||||
&command,
|
||||
approval_policy,
|
||||
&sandbox_policy,
|
||||
&approved,
|
||||
request_escalated_privileges,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
safety_check,
|
||||
SafetyCheck::Reject {
|
||||
reason: "dangerous command detected; rejected by user approval settings"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_escalated_privileges_no_sandbox_fallback() {
|
||||
let command = vec!["git".to_string(), "commit".to_string()];
|
||||
|
||||
9
codex-rs/core/src/state/mod.rs
Normal file
9
codex-rs/core/src/state/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod service;
|
||||
mod session;
|
||||
mod turn;
|
||||
|
||||
pub(crate) use service::SessionServices;
|
||||
pub(crate) use session::SessionState;
|
||||
pub(crate) use turn::ActiveTurn;
|
||||
pub(crate) use turn::RunningTask;
|
||||
pub(crate) use turn::TaskKind;
|
||||
18
codex-rs/core/src/state/service.rs
Normal file
18
codex-rs/core/src/state/service.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use crate::RolloutRecorder;
|
||||
use crate::exec_command::ExecSessionManager;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_connection_manager: McpConnectionManager,
|
||||
pub(crate) session_manager: ExecSessionManager,
|
||||
pub(crate) unified_exec_manager: UnifiedExecSessionManager,
|
||||
pub(crate) notifier: UserNotifier,
|
||||
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
|
||||
pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub(crate) user_shell: crate::shell::Shell,
|
||||
pub(crate) show_raw_agent_reasoning: bool,
|
||||
}
|
||||
80
codex-rs/core/src/state/session.rs
Normal file
80
codex-rs/core/src/state/session.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
//! Session-wide mutable state.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::protocol::TokenUsageInfo;
|
||||
|
||||
/// Persistent, session-scoped state previously stored directly on `Session`.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct SessionState {
|
||||
pub(crate) approved_commands: HashSet<Vec<String>>,
|
||||
pub(crate) history: ConversationHistory,
|
||||
pub(crate) token_info: Option<TokenUsageInfo>,
|
||||
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
/// Create a new session state mirroring previous `State::default()` semantics.
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
history: ConversationHistory::new(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// History helpers
|
||||
pub(crate) fn record_items<I>(&mut self, items: I)
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: std::ops::Deref<Target = ResponseItem>,
|
||||
{
|
||||
self.history.record_items(items)
|
||||
}
|
||||
|
||||
pub(crate) fn history_snapshot(&self) -> Vec<ResponseItem> {
|
||||
self.history.contents()
|
||||
}
|
||||
|
||||
pub(crate) fn replace_history(&mut self, items: Vec<ResponseItem>) {
|
||||
self.history.replace(items);
|
||||
}
|
||||
|
||||
// Approved command helpers
|
||||
pub(crate) fn add_approved_command(&mut self, cmd: Vec<String>) {
|
||||
self.approved_commands.insert(cmd);
|
||||
}
|
||||
|
||||
pub(crate) fn approved_commands_ref(&self) -> &HashSet<Vec<String>> {
|
||||
&self.approved_commands
|
||||
}
|
||||
|
||||
// Token/rate limit helpers
|
||||
pub(crate) fn update_token_info_from_usage(
|
||||
&mut self,
|
||||
usage: &TokenUsage,
|
||||
model_context_window: Option<u64>,
|
||||
) {
|
||||
self.token_info = TokenUsageInfo::new_or_append(
|
||||
&self.token_info,
|
||||
&Some(usage.clone()),
|
||||
model_context_window,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) {
|
||||
self.latest_rate_limits = Some(snapshot);
|
||||
}
|
||||
|
||||
pub(crate) fn token_info_and_rate_limits(
|
||||
&self,
|
||||
) -> (Option<TokenUsageInfo>, Option<RateLimitSnapshot>) {
|
||||
(self.token_info.clone(), self.latest_rate_limits.clone())
|
||||
}
|
||||
|
||||
// Pending input/approval moved to TurnState.
|
||||
}
|
||||
115
codex-rs/core/src/state/turn.rs
Normal file
115
codex-rs/core/src/state/turn.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
//! Turn-scoped state and active turn metadata scaffolding.
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::AbortHandle;
|
||||
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::tasks::SessionTask;
|
||||
|
||||
/// Metadata about the currently running turn.
|
||||
pub(crate) struct ActiveTurn {
|
||||
pub(crate) tasks: IndexMap<String, RunningTask>,
|
||||
pub(crate) turn_state: Arc<Mutex<TurnState>>,
|
||||
}
|
||||
|
||||
impl Default for ActiveTurn {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tasks: IndexMap::new(),
|
||||
turn_state: Arc::new(Mutex::new(TurnState::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum TaskKind {
|
||||
Regular,
|
||||
Review,
|
||||
Compact,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RunningTask {
|
||||
pub(crate) handle: AbortHandle,
|
||||
pub(crate) kind: TaskKind,
|
||||
pub(crate) task: Arc<dyn SessionTask>,
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) {
|
||||
self.tasks.insert(sub_id, task);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
|
||||
self.tasks.swap_remove(sub_id);
|
||||
self.tasks.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn drain_tasks(&mut self) -> IndexMap<String, RunningTask> {
|
||||
std::mem::take(&mut self.tasks)
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable state for a single turn.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TurnState {
|
||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
}
|
||||
|
||||
impl TurnState {
|
||||
pub(crate) fn insert_pending_approval(
|
||||
&mut self,
|
||||
key: String,
|
||||
tx: oneshot::Sender<ReviewDecision>,
|
||||
) -> Option<oneshot::Sender<ReviewDecision>> {
|
||||
self.pending_approvals.insert(key, tx)
|
||||
}
|
||||
|
||||
pub(crate) fn remove_pending_approval(
|
||||
&mut self,
|
||||
key: &str,
|
||||
) -> Option<oneshot::Sender<ReviewDecision>> {
|
||||
self.pending_approvals.remove(key)
|
||||
}
|
||||
|
||||
pub(crate) fn clear_pending(&mut self) {
|
||||
self.pending_approvals.clear();
|
||||
self.pending_input.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) {
|
||||
self.pending_input.push(input);
|
||||
}
|
||||
|
||||
pub(crate) fn take_pending_input(&mut self) -> Vec<ResponseInputItem> {
|
||||
if self.pending_input.is_empty() {
|
||||
Vec::with_capacity(0)
|
||||
} else {
|
||||
let mut ret = Vec::new();
|
||||
std::mem::swap(&mut ret, &mut self.pending_input);
|
||||
ret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
/// Clear any pending approvals and input buffered for the current turn.
|
||||
pub(crate) async fn clear_pending(&self) {
|
||||
let mut ts = self.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
|
||||
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
|
||||
pub(crate) fn try_clear_pending_sync(&self) {
|
||||
if let Ok(mut ts) = self.turn_state.try_lock() {
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
}
|
||||
31
codex-rs/core/src/tasks/compact.rs
Normal file
31
codex-rs/core/src/tasks/compact.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::compact;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct CompactTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for CompactTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Compact
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
||||
}
|
||||
}
|
||||
166
codex-rs/core/src/tasks/mod.rs
Normal file
166
codex-rs/core/src/tasks/mod.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
mod compact;
|
||||
mod regular;
|
||||
mod review;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TurnAbortReason;
|
||||
use crate::protocol::TurnAbortedEvent;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::RunningTask;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
pub(crate) use compact::CompactTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
pub(crate) use review::ReviewTask;
|
||||
|
||||
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionTaskContext {
|
||||
session: Arc<Session>,
|
||||
}
|
||||
|
||||
impl SessionTaskContext {
|
||||
pub(crate) fn new(session: Arc<Session>) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
|
||||
pub(crate) fn clone_session(&self) -> Arc<Session> {
|
||||
Arc::clone(&self.session)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait SessionTask: Send + Sync + 'static {
|
||||
fn kind(&self) -> TaskKind;
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String>;
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
let _ = (session, sub_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn spawn_task<T: SessionTask>(
|
||||
self: &Arc<Self>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
task: T,
|
||||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
|
||||
let handle = {
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
let task_for_run = Arc::clone(&task);
|
||||
let sub_clone = sub_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let last_agent_message = task_for_run
|
||||
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input)
|
||||
.await;
|
||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||
let sess = session_ctx.clone_session();
|
||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
||||
})
|
||||
.abort_handle()
|
||||
};
|
||||
|
||||
let running_task = RunningTask {
|
||||
handle,
|
||||
kind: task_kind,
|
||||
task,
|
||||
};
|
||||
self.register_new_active_task(sub_id, running_task).await;
|
||||
}
|
||||
|
||||
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
|
||||
for (sub_id, task) in self.take_all_running_tasks().await {
|
||||
self.handle_task_abort(sub_id, task, reason.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
self: &Arc<Self>,
|
||||
sub_id: String,
|
||||
last_agent_message: Option<String>,
|
||||
) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut()
|
||||
&& at.remove_task(&sub_id)
|
||||
{
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn register_new_active_task(&self, sub_id: String, task: RunningTask) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let mut turn = ActiveTurn::default();
|
||||
turn.add_task(sub_id, task);
|
||||
*active = Some(turn);
|
||||
}
|
||||
|
||||
async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.take() {
|
||||
Some(mut at) => {
|
||||
at.clear_pending().await;
|
||||
let tasks = at.drain_tasks();
|
||||
tasks.into_iter().collect()
|
||||
}
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_task_abort(
|
||||
self: &Arc<Self>,
|
||||
sub_id: String,
|
||||
task: RunningTask,
|
||||
reason: TurnAbortReason,
|
||||
) {
|
||||
if task.handle.is_finished() {
|
||||
return;
|
||||
}
|
||||
|
||||
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
|
||||
let session_task = task.task;
|
||||
let handle = task.handle;
|
||||
handle.abort();
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
session_task.abort(session_ctx, &sub_id).await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
32
codex-rs/core/src/tasks/regular.rs
Normal file
32
codex-rs/core/src/tasks/regular.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::run_task;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct RegularTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for RegularTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(sess, ctx, sub_id, input).await
|
||||
}
|
||||
}
|
||||
37
codex-rs/core/src/tasks/review.rs
Normal file
37
codex-rs/core/src/tasks/review.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::exit_review_mode;
|
||||
use crate::codex::run_task;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct ReviewTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for ReviewTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Review
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(sess, ctx, sub_id, input).await
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
}
|
||||
}
|
||||
@@ -10,11 +10,6 @@ use crate::openai_tools::ResponsesApiTool;
|
||||
|
||||
const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct ApplyPatchToolArgs {
|
||||
pub(crate) input: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ApplyPatchToolType {
|
||||
|
||||
@@ -404,6 +404,8 @@ async fn create_unified_exec_session(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[cfg(unix)]
|
||||
use core_test_support::skip_if_sandbox;
|
||||
|
||||
#[test]
|
||||
fn push_chunk_trims_only_excess_bytes() {
|
||||
@@ -425,6 +427,8 @@ mod tests {
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn unified_exec_persists_across_requests_jif() -> Result<(), UnifiedExecError> {
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
@@ -462,6 +466,8 @@ mod tests {
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let shell_a = manager
|
||||
@@ -508,6 +514,8 @@ mod tests {
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
@@ -601,6 +609,8 @@ mod tests {
|
||||
#[cfg(unix)]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let manager = UnifiedExecSessionManager::default();
|
||||
|
||||
let open_shell = manager
|
||||
|
||||
@@ -8,6 +8,7 @@ path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
assert_cmd = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -9,6 +9,7 @@ use codex_core::config::ConfigToml;
|
||||
|
||||
pub mod responses;
|
||||
pub mod test_codex;
|
||||
pub mod test_codex_exec;
|
||||
|
||||
/// Returns a default `Config` whose on-disk state is confined to the provided
|
||||
/// temporary directory. Using a per-test directory keeps tests hermetic and
|
||||
@@ -128,20 +129,56 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sandbox_env_var() -> &'static str {
|
||||
codex_core::spawn::CODEX_SANDBOX_ENV_VAR
|
||||
}
|
||||
|
||||
pub fn sandbox_network_env_var() -> &'static str {
|
||||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! non_sandbox_test {
|
||||
// For tests that return ()
|
||||
macro_rules! skip_if_sandbox {
|
||||
() => {{
|
||||
if ::std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
|
||||
println!("Skipping test because it cannot execute when network is disabled in a Codex sandbox.");
|
||||
if ::std::env::var($crate::sandbox_env_var())
|
||||
== ::core::result::Result::Ok("seatbelt".to_string())
|
||||
{
|
||||
eprintln!(
|
||||
"{} is set to 'seatbelt', skipping test.",
|
||||
$crate::sandbox_env_var()
|
||||
);
|
||||
return;
|
||||
}
|
||||
}};
|
||||
// For tests that return Result<(), _>
|
||||
(result $(,)?) => {{
|
||||
if ::std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
|
||||
println!("Skipping test because it cannot execute when network is disabled in a Codex sandbox.");
|
||||
return ::core::result::Result::Ok(());
|
||||
($return_value:expr $(,)?) => {{
|
||||
if ::std::env::var($crate::sandbox_env_var())
|
||||
== ::core::result::Result::Ok("seatbelt".to_string())
|
||||
{
|
||||
eprintln!(
|
||||
"{} is set to 'seatbelt', skipping test.",
|
||||
$crate::sandbox_env_var()
|
||||
);
|
||||
return $return_value;
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! skip_if_no_network {
|
||||
() => {{
|
||||
if ::std::env::var($crate::sandbox_network_env_var()).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
}};
|
||||
($return_value:expr $(,)?) => {{
|
||||
if ::std::env::var($crate::sandbox_network_env_var()).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return $return_value;
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use serde_json::Value;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
@@ -121,6 +122,7 @@ where
|
||||
.and(path("/v1/responses"))
|
||||
.and(matcher)
|
||||
.respond_with(sse_response(body))
|
||||
.up_to_n_times(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
}
|
||||
@@ -131,3 +133,41 @@ pub async fn start_mock_server() -> MockServer {
|
||||
.start()
|
||||
.await
|
||||
}
|
||||
|
||||
/// Mounts a sequence of SSE response bodies and serves them in order for each
|
||||
/// POST to `/v1/responses`. Panics if more requests are received than bodies
|
||||
/// provided. Also asserts the exact number of expected calls.
|
||||
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
struct SeqResponder {
|
||||
num_calls: AtomicUsize,
|
||||
responses: Vec<String>,
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
|
||||
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
|
||||
match self.responses.get(call_num) {
|
||||
Some(body) => ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_string(body.clone()),
|
||||
None => panic!("no response for {call_num}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let num_calls = bodies.len();
|
||||
let responder = SeqResponder {
|
||||
num_calls: AtomicUsize::new(0),
|
||||
responses: bodies,
|
||||
};
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(responder)
|
||||
.expect(num_calls as u64)
|
||||
.mount(server)
|
||||
.await;
|
||||
}
|
||||
|
||||
40
codex-rs/core/tests/common/test_codex_exec.rs
Normal file
40
codex-rs/core/tests/common/test_codex_exec.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
use std::path::Path;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::MockServer;
|
||||
|
||||
pub struct TestCodexExecBuilder {
|
||||
home: TempDir,
|
||||
cwd: TempDir,
|
||||
}
|
||||
|
||||
impl TestCodexExecBuilder {
|
||||
pub fn cmd(&self) -> assert_cmd::Command {
|
||||
let mut cmd = assert_cmd::Command::cargo_bin("codex-exec")
|
||||
.expect("should find binary for codex-exec");
|
||||
cmd.current_dir(self.cwd.path())
|
||||
.env("CODEX_HOME", self.home.path())
|
||||
.env("OPENAI_API_KEY", "dummy");
|
||||
cmd
|
||||
}
|
||||
pub fn cmd_with_server(&self, server: &MockServer) -> assert_cmd::Command {
|
||||
let mut cmd = self.cmd();
|
||||
let base = format!("{}/v1", server.uri());
|
||||
cmd.env("OPENAI_BASE_URL", base);
|
||||
cmd
|
||||
}
|
||||
|
||||
pub fn cwd_path(&self) -> &Path {
|
||||
self.cwd.path()
|
||||
}
|
||||
pub fn home_path(&self) -> &Path {
|
||||
self.home.path()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_codex_exec() -> TestCodexExecBuilder {
|
||||
TestCodexExecBuilder {
|
||||
home: TempDir::new().expect("create temp home"),
|
||||
cwd: TempDir::new().expect("create temp cwd"),
|
||||
}
|
||||
}
|
||||
66
codex-rs/core/tests/suite/abort_tasks.rs
Normal file
66
codex-rs/core/tests/suite/abort_tasks.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
|
||||
/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
|
||||
/// function call, then interrupt the session and expect TurnAborted.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn interrupt_long_running_tool_emits_turn_aborted() {
|
||||
let command = vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 60".to_string(),
|
||||
];
|
||||
|
||||
let args = json!({
|
||||
"command": command,
|
||||
"timeout_ms": 60_000
|
||||
})
|
||||
.to_string();
|
||||
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
|
||||
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_once(&server, body_string_contains("start sleep"), body).await;
|
||||
|
||||
let codex = test_codex().build(&server).await.unwrap().codex;
|
||||
|
||||
let wait_timeout = Duration::from_secs(5);
|
||||
|
||||
// Kick off a turn that triggers the function call.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "start sleep".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait until the exec begins to avoid a race, then interrupt.
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::ExecCommandBegin(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
|
||||
codex.submit(Op::Interrupt).await.unwrap();
|
||||
|
||||
// Expect TurnAborted soon after.
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::TurnAborted(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::protocol::GitInfo;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
@@ -21,7 +21,7 @@ use wiremock::matchers::path;
|
||||
/// 4. Ensures the response is received exactly once and contains "hi"
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn chat_mode_stream_cli() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
let sse = concat!(
|
||||
@@ -97,7 +97,7 @@ async fn chat_mode_stream_cli() {
|
||||
/// received by a mock OpenAI Responses endpoint.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_cli_applies_experimental_instructions_file() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Start mock server which will capture the request and return a minimal
|
||||
// SSE stream for a single turn.
|
||||
@@ -185,7 +185,7 @@ async fn exec_cli_applies_experimental_instructions_file() {
|
||||
/// 4. Ensures the fixture content is correctly streamed through the CLI
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_api_stream_cli() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let fixture =
|
||||
std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/cli_responses_fixture.sse");
|
||||
@@ -217,7 +217,7 @@ async fn responses_api_stream_cli() {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_creates_and_checks_session_file() {
|
||||
// Honor sandbox network restrictions for CI parity with the other tests.
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// 1. Temp home so we read/write isolated session files.
|
||||
let home = TempDir::new().unwrap();
|
||||
|
||||
@@ -21,8 +21,8 @@ use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::models::WebSearchAction;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use futures::StreamExt;
|
||||
@@ -127,7 +127,7 @@ fn write_auth_json(
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Create a fake rollout session file with prior user + system + assistant messages.
|
||||
let tmpdir = TempDir::new().unwrap();
|
||||
@@ -293,7 +293,7 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_conversation_id_and_model_headers_in_request() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
@@ -361,6 +361,7 @@ async fn includes_conversation_id_and_model_headers_in_request() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_base_instructions_override_in_request() {
|
||||
skip_if_no_network!();
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
@@ -418,7 +419,7 @@ async fn includes_base_instructions_override_in_request() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn chatgpt_auth_sends_correct_request() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
@@ -492,7 +493,7 @@ async fn chatgpt_auth_sends_correct_request() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
@@ -558,6 +559,7 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_user_instructions_message_in_request() {
|
||||
skip_if_no_network!();
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let first = ResponseTemplate::new(200)
|
||||
@@ -619,7 +621,7 @@ async fn includes_user_instructions_message_in_request() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
@@ -755,6 +757,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn token_count_includes_rate_limits_snapshot() {
|
||||
skip_if_no_network!();
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let sse_body = responses::sse(vec![responses::ev_completed_with_tokens("resp_rate", 123)]);
|
||||
@@ -763,9 +766,10 @@ async fn token_count_includes_rate_limits_snapshot() {
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.insert_header("x-codex-primary-used-percent", "12.5")
|
||||
.insert_header("x-codex-secondary-used-percent", "40.0")
|
||||
.insert_header("x-codex-primary-over-secondary-limit-percent", "75.0")
|
||||
.insert_header("x-codex-primary-window-minutes", "10")
|
||||
.insert_header("x-codex-secondary-window-minutes", "60")
|
||||
.insert_header("x-codex-primary-reset-after-seconds", "1800")
|
||||
.insert_header("x-codex-secondary-reset-after-seconds", "7200")
|
||||
.set_body_raw(sse_body, "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
@@ -811,11 +815,16 @@ async fn token_count_includes_rate_limits_snapshot() {
|
||||
json!({
|
||||
"info": null,
|
||||
"rate_limits": {
|
||||
"primary_used_percent": 12.5,
|
||||
"secondary_used_percent": 40.0,
|
||||
"primary_to_secondary_ratio_percent": 75.0,
|
||||
"primary_window_minutes": 10,
|
||||
"secondary_window_minutes": 60
|
||||
"primary": {
|
||||
"used_percent": 12.5,
|
||||
"window_minutes": 10,
|
||||
"resets_in_seconds": 1800
|
||||
},
|
||||
"secondary": {
|
||||
"used_percent": 40.0,
|
||||
"window_minutes": 60,
|
||||
"resets_in_seconds": 7200
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
@@ -853,11 +862,16 @@ async fn token_count_includes_rate_limits_snapshot() {
|
||||
"model_context_window": 272000
|
||||
},
|
||||
"rate_limits": {
|
||||
"primary_used_percent": 12.5,
|
||||
"secondary_used_percent": 40.0,
|
||||
"primary_to_secondary_ratio_percent": 75.0,
|
||||
"primary_window_minutes": 10,
|
||||
"secondary_window_minutes": 60
|
||||
"primary": {
|
||||
"used_percent": 12.5,
|
||||
"window_minutes": 10,
|
||||
"resets_in_seconds": 1800
|
||||
},
|
||||
"secondary": {
|
||||
"used_percent": 40.0,
|
||||
"window_minutes": 60,
|
||||
"resets_in_seconds": 7200
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
@@ -868,13 +882,27 @@ async fn token_count_includes_rate_limits_snapshot() {
|
||||
let final_snapshot = final_payload
|
||||
.rate_limits
|
||||
.expect("latest rate limit snapshot should be retained");
|
||||
assert_eq!(final_snapshot.primary_used_percent, 12.5);
|
||||
assert_eq!(
|
||||
final_snapshot
|
||||
.primary
|
||||
.as_ref()
|
||||
.map(|window| window.used_percent),
|
||||
Some(12.5)
|
||||
);
|
||||
assert_eq!(
|
||||
final_snapshot
|
||||
.primary
|
||||
.as_ref()
|
||||
.and_then(|window| window.resets_in_seconds),
|
||||
Some(1800)
|
||||
);
|
||||
|
||||
wait_for_event(&codex, |msg| matches!(msg, EventMsg::TaskComplete(_))).await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let response = ResponseTemplate::new(429)
|
||||
@@ -904,11 +932,16 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
|
||||
let codex = codex_fixture.codex.clone();
|
||||
|
||||
let expected_limits = json!({
|
||||
"primary_used_percent": 100.0,
|
||||
"secondary_used_percent": 87.5,
|
||||
"primary_to_secondary_ratio_percent": 95.0,
|
||||
"primary_window_minutes": 15,
|
||||
"secondary_window_minutes": 60
|
||||
"primary": {
|
||||
"used_percent": 100.0,
|
||||
"window_minutes": 15,
|
||||
"resets_in_seconds": null
|
||||
},
|
||||
"secondary": {
|
||||
"used_percent": 87.5,
|
||||
"window_minutes": 60,
|
||||
"resets_in_seconds": null
|
||||
}
|
||||
});
|
||||
|
||||
let submission_id = codex
|
||||
@@ -949,6 +982,7 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn azure_overrides_assign_properties_used_for_responses_url() {
|
||||
skip_if_no_network!();
|
||||
let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
|
||||
|
||||
// Mock server
|
||||
@@ -1025,6 +1059,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn env_var_overrides_loaded_auth() {
|
||||
skip_if_no_network!();
|
||||
let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
|
||||
|
||||
// Mock server
|
||||
@@ -1112,7 +1147,7 @@ fn create_dummy_codex_auth() -> CodexAuth {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||
// Skip under Codex sandbox network restrictions (mirrors other tests).
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Mock server that will receive three sequential requests and return the same SSE stream
|
||||
// each time: a few deltas, then a final assistant message, then completed.
|
||||
|
||||
@@ -10,6 +10,7 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
@@ -20,7 +21,6 @@ use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use codex_core::codex::compact::SUMMARIZATION_PROMPT;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_completed_with_tokens;
|
||||
@@ -53,7 +53,7 @@ const DUMMY_CALL_ID: &str = "call-multi-auto";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn summarize_context_three_requests_and_instructions() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Set up a mock server that we can inspect after the run.
|
||||
let server = start_mock_server().await;
|
||||
@@ -270,7 +270,7 @@ async fn summarize_context_three_requests_and_instructions() {
|
||||
#[cfg_attr(windows, tokio::test(flavor = "multi_thread", worker_threads = 4))]
|
||||
#[cfg_attr(not(windows), tokio::test(flavor = "multi_thread", worker_threads = 2))]
|
||||
async fn auto_compact_runs_after_token_limit_hit() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
@@ -430,7 +430,7 @@ async fn auto_compact_runs_after_token_limit_hit() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn auto_compact_persists_rollout_entries() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
@@ -558,7 +558,7 @@ async fn auto_compact_persists_rollout_entries() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn auto_compact_stops_after_failed_attempt() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
@@ -679,7 +679,7 @@ async fn auto_compact_stops_after_failed_attempt() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_events() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
@@ -29,6 +30,8 @@ fn sse_completed(id: &str) -> String {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn fork_conversation_twice_drops_to_first_message() {
|
||||
skip_if_no_network!();
|
||||
|
||||
// Start a mock server that completes three turns.
|
||||
let server = MockServer::start().await;
|
||||
let sse = sse_completed("resp");
|
||||
|
||||
@@ -6,8 +6,8 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -30,8 +30,17 @@ const SCHEMA: &str = r#"
|
||||
"#;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn codex_returns_json_result() -> anyhow::Result<()> {
|
||||
non_sandbox_test!(result);
|
||||
async fn codex_returns_json_result_for_gpt5() -> anyhow::Result<()> {
|
||||
codex_returns_json_result("gpt-5".to_string()).await
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn codex_returns_json_result_for_gpt5_codex() -> anyhow::Result<()> {
|
||||
codex_returns_json_result("gpt-5-codex".to_string()).await
|
||||
}
|
||||
|
||||
async fn codex_returns_json_result(model: String) -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
@@ -72,7 +81,7 @@ async fn codex_returns_json_result() -> anyhow::Result<()> {
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: "gpt-5".to_string(),
|
||||
model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
// Aggregates all former standalone integration tests as modules.
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
mod abort_tasks;
|
||||
mod cli_stream;
|
||||
mod client;
|
||||
mod compact;
|
||||
@@ -12,6 +14,7 @@ mod live_cli;
|
||||
mod model_overrides;
|
||||
mod prompt_caching;
|
||||
mod review;
|
||||
mod rmcp_client;
|
||||
mod rollout_list_find;
|
||||
mod seatbelt;
|
||||
mod stream_error_allows_next_turn;
|
||||
|
||||
@@ -16,6 +16,7 @@ use codex_core::shell::Shell;
|
||||
use codex_core::shell::default_user_shell;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
@@ -67,6 +68,7 @@ fn assert_tool_names(body: &serde_json::Value, expected_names: &[&str]) {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn codex_mini_latest_tools() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -151,6 +153,7 @@ async fn codex_mini_latest_tools() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn prompt_tools_are_consistent_across_requests() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -234,6 +237,7 @@ async fn prompt_tools_are_consistent_across_requests() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn prefixes_context_and_instructions_once_and_consistently_across_requests() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -352,6 +356,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -479,6 +484,7 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -602,6 +608,7 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn send_user_turn_with_no_changes_does_not_send_environment_context() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
@@ -715,6 +722,7 @@ async fn send_user_turn_with_no_changes_does_not_send_environment_context() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn send_user_turn_with_changes_sends_environment_context() {
|
||||
skip_if_no_network!();
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
@@ -22,7 +22,7 @@ use codex_core::protocol::RolloutItem;
|
||||
use codex_core::protocol::RolloutLine;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id_from_str;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::wait_for_event;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::PathBuf;
|
||||
@@ -42,7 +42,7 @@ use wiremock::matchers::path;
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn review_op_emits_lifecycle_and_review_output() {
|
||||
// Skip under Codex sandbox network restrictions.
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Start mock Responses API server. Return a single assistant message whose
|
||||
// text is a JSON-encoded ReviewOutputEvent.
|
||||
@@ -167,7 +167,7 @@ async fn review_op_emits_lifecycle_and_review_output() {
|
||||
#[cfg_attr(windows, tokio::test(flavor = "multi_thread", worker_threads = 4))]
|
||||
#[cfg_attr(not(windows), tokio::test(flavor = "multi_thread", worker_threads = 2))]
|
||||
async fn review_op_with_plain_text_emits_review_fallback() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let sse_raw = r#"[
|
||||
{"type":"response.output_item.done", "item":{
|
||||
@@ -216,7 +216,7 @@ async fn review_op_with_plain_text_emits_review_fallback() {
|
||||
#[cfg_attr(windows, tokio::test(flavor = "multi_thread", worker_threads = 4))]
|
||||
#[cfg_attr(not(windows), tokio::test(flavor = "multi_thread", worker_threads = 2))]
|
||||
async fn review_does_not_emit_agent_message_on_structured_output() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let review_json = serde_json::json!({
|
||||
"findings": [
|
||||
@@ -288,7 +288,7 @@ async fn review_does_not_emit_agent_message_on_structured_output() {
|
||||
/// request uses that model (and not the main chat model).
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn review_uses_custom_review_model_from_config() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Minimal stream: just a completed event
|
||||
let sse_raw = r#"[
|
||||
@@ -341,7 +341,7 @@ async fn review_uses_custom_review_model_from_config() {
|
||||
#[cfg_attr(windows, tokio::test(flavor = "multi_thread", worker_threads = 4))]
|
||||
#[cfg_attr(not(windows), tokio::test(flavor = "multi_thread", worker_threads = 2))]
|
||||
async fn review_input_isolated_from_parent_history() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Mock server for the single review request
|
||||
let sse_raw = r#"[
|
||||
@@ -517,7 +517,7 @@ async fn review_input_isolated_from_parent_history() {
|
||||
/// messages in its request `input`.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn review_history_does_not_leak_into_parent_session() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Respond to both the review request and the subsequent parent request.
|
||||
let sse_raw = r#"[
|
||||
|
||||
162
codex-rs/core/tests/suite/rmcp_client.rs
Normal file
162
codex-rs/core/tests/suite/rmcp_client.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use escargot::CargoBuild;
|
||||
use serde_json::Value;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
let call_id = "call-123";
|
||||
let server_name = "rmcp";
|
||||
let tool_name = format!("{server_name}__echo");
|
||||
|
||||
mount_sse_once(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
mount_sse_once(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_env_value = "propagated-env";
|
||||
let rmcp_test_server_bin = CargoBuild::new()
|
||||
.package("codex-rmcp-client")
|
||||
.bin("rmcp_test_server")
|
||||
.run()?
|
||||
.path()
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.use_experimental_use_rmcp_client = true;
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
command: rmcp_test_server_bin.clone(),
|
||||
args: Vec::new(),
|
||||
env: Some(HashMap::from([(
|
||||
"MCP_TEST_VALUE".to_string(),
|
||||
expected_env_value.to_string(),
|
||||
)])),
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
fixture
|
||||
.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "call the rmcp echo tool".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
eprintln!("waiting for mcp tool call begin event");
|
||||
let begin_event = wait_for_event_with_timeout(
|
||||
&fixture.codex,
|
||||
|ev| {
|
||||
eprintln!("ev: {ev:?}");
|
||||
matches!(ev, EventMsg::McpToolCallBegin(_))
|
||||
},
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
eprintln!("mcp tool call begin event: {begin_event:?}");
|
||||
let EventMsg::McpToolCallBegin(begin) = begin_event else {
|
||||
unreachable!("event guard guarantees McpToolCallBegin");
|
||||
};
|
||||
assert_eq!(begin.invocation.server, server_name);
|
||||
assert_eq!(begin.invocation.tool, "echo");
|
||||
|
||||
let end_event = wait_for_event(&fixture.codex, |ev| {
|
||||
matches!(ev, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
eprintln!("end_event: {end_event:?}");
|
||||
let EventMsg::McpToolCallEnd(end) = end_event else {
|
||||
unreachable!("event guard guarantees McpToolCallEnd");
|
||||
};
|
||||
|
||||
let result = end
|
||||
.result
|
||||
.as_ref()
|
||||
.expect("rmcp echo tool should return success");
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
assert!(
|
||||
result.content.is_empty(),
|
||||
"content should default to an empty array"
|
||||
);
|
||||
|
||||
let structured = result
|
||||
.structured_content
|
||||
.as_ref()
|
||||
.expect("structured content");
|
||||
let Value::Object(map) = structured else {
|
||||
panic!("structured content should be an object: {structured:?}");
|
||||
};
|
||||
let echo_value = map
|
||||
.get("echo")
|
||||
.and_then(Value::as_str)
|
||||
.expect("echo payload present");
|
||||
assert_eq!(echo_value, "ping");
|
||||
let env_value = map
|
||||
.get("env")
|
||||
.and_then(Value::as_str)
|
||||
.expect("env snapshot inserted");
|
||||
assert_eq!(env_value, expected_env_value);
|
||||
|
||||
let task_complete_event =
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
eprintln!("task_complete_event: {task_complete_event:?}");
|
||||
|
||||
server.verify().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -6,7 +6,7 @@ use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
@@ -23,7 +23,7 @@ fn sse_completed(id: &str) -> String {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn continue_after_stream_error() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::load_sse_fixture;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use tokio::time::timeout;
|
||||
@@ -32,7 +32,7 @@ fn sse_completed(id: &str) -> String {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn retries_on_early_close() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ use std::os::unix::fs::PermissionsExt;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
@@ -22,7 +22,7 @@ use tokio::time::sleep;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn summarize_context_three_requests_and_instructions() -> anyhow::Result<()> {
|
||||
non_sandbox_test!(result);
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ codex-core = { workspace = true }
|
||||
codex-ollama = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
owo-colors = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
@@ -39,12 +40,18 @@ tokio = { workspace = true, features = [
|
||||
] }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
ts-rs = { workspace = true, features = [
|
||||
"uuid-impl",
|
||||
"serde-json-impl",
|
||||
"no-serde-warnings",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
core_test_support = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
|
||||
@@ -64,9 +64,20 @@ pub struct Cli {
|
||||
pub color: Color,
|
||||
|
||||
/// Print events to stdout as JSONL.
|
||||
#[arg(long = "json", default_value_t = false)]
|
||||
#[arg(
|
||||
long = "json",
|
||||
default_value_t = false,
|
||||
conflicts_with = "experimental_json"
|
||||
)]
|
||||
pub json: bool,
|
||||
|
||||
#[arg(
|
||||
long = "experimental-json",
|
||||
default_value_t = false,
|
||||
conflicts_with = "json"
|
||||
)]
|
||||
pub experimental_json: bool,
|
||||
|
||||
/// Whether to include the plan tool in the conversation.
|
||||
#[arg(long = "include-plan-tool", default_value_t = false)]
|
||||
pub include_plan_tool: bool,
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::path::Path;
|
||||
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
|
||||
pub(crate) enum CodexStatus {
|
||||
Running,
|
||||
@@ -11,7 +12,12 @@ pub(crate) enum CodexStatus {
|
||||
|
||||
pub(crate) trait EventProcessor {
|
||||
/// Print summary of effective configuration and user prompt.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str);
|
||||
fn print_config_summary(
|
||||
&mut self,
|
||||
config: &Config,
|
||||
prompt: &str,
|
||||
session_configured: &SessionConfiguredEvent,
|
||||
);
|
||||
|
||||
/// Handle a single event emitted by the agent.
|
||||
fn process_event(&mut self, event: Event) -> CodexStatus;
|
||||
|
||||
@@ -141,7 +141,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
/// Print a concise summary of the effective configuration that will be used
|
||||
/// for the session. This mirrors the information shown in the TUI welcome
|
||||
/// screen.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str, _: &SessionConfiguredEvent) {
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
ts_println!(
|
||||
self,
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::path::PathBuf;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use serde_json::json;
|
||||
|
||||
@@ -23,7 +24,7 @@ impl EventProcessorWithJsonOutput {
|
||||
}
|
||||
|
||||
impl EventProcessor for EventProcessorWithJsonOutput {
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str, _: &SessionConfiguredEvent) {
|
||||
let entries = create_config_summary_entries(config)
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.to_string(), value))
|
||||
|
||||
191
codex-rs/exec/src/exec_events.rs
Normal file
191
codex-rs/exec/src/exec_events.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use ts_rs::TS;
|
||||
|
||||
/// Top-level events emitted on the Codex Exec conversation stream.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ConversationEvent {
|
||||
#[serde(rename = "session.created")]
|
||||
SessionCreated(SessionCreatedEvent),
|
||||
#[serde(rename = "turn.started")]
|
||||
TurnStarted(TurnStartedEvent),
|
||||
#[serde(rename = "turn.completed")]
|
||||
TurnCompleted(TurnCompletedEvent),
|
||||
#[serde(rename = "item.started")]
|
||||
ItemStarted(ItemStartedEvent),
|
||||
#[serde(rename = "item.updated")]
|
||||
ItemUpdated(ItemUpdatedEvent),
|
||||
#[serde(rename = "item.completed")]
|
||||
ItemCompleted(ItemCompletedEvent),
|
||||
#[serde(rename = "error")]
|
||||
Error(ConversationErrorEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct SessionCreatedEvent {
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS, Default)]
|
||||
pub struct TurnStartedEvent {}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct TurnCompletedEvent {
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Minimal usage summary for a turn.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS, Default)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: u64,
|
||||
pub cached_input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ItemStartedEvent {
|
||||
pub item: ConversationItem,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ItemCompletedEvent {
|
||||
pub item: ConversationItem,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ItemUpdatedEvent {
|
||||
pub item: ConversationItem,
|
||||
}
|
||||
|
||||
/// Fatal error emitted by the stream.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ConversationErrorEvent {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Canonical representation of a conversation item and its domain-specific payload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ConversationItem {
|
||||
pub id: String,
|
||||
#[serde(flatten)]
|
||||
pub details: ConversationItemDetails,
|
||||
}
|
||||
|
||||
/// Typed payloads for each supported conversation item type.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
#[serde(tag = "item_type", rename_all = "snake_case")]
|
||||
pub enum ConversationItemDetails {
|
||||
AssistantMessage(AssistantMessageItem),
|
||||
Reasoning(ReasoningItem),
|
||||
CommandExecution(CommandExecutionItem),
|
||||
FileChange(FileChangeItem),
|
||||
McpToolCall(McpToolCallItem),
|
||||
WebSearch(WebSearchItem),
|
||||
TodoList(TodoListItem),
|
||||
Error(ErrorItem),
|
||||
}
|
||||
|
||||
/// Session conversation metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct SessionItem {
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Assistant message payload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct AssistantMessageItem {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Model reasoning summary payload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ReasoningItem {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default, TS)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CommandExecutionStatus {
|
||||
#[default]
|
||||
InProgress,
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// Local shell command execution payload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct CommandExecutionItem {
|
||||
pub command: String,
|
||||
pub aggregated_output: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exit_code: Option<i32>,
|
||||
pub status: CommandExecutionStatus,
|
||||
}
|
||||
|
||||
/// Single file change summary for a patch.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct FileUpdateChange {
|
||||
pub path: String,
|
||||
pub kind: PatchChangeKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PatchApplyStatus {
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// Patch application payload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct FileChangeItem {
|
||||
pub changes: Vec<FileUpdateChange>,
|
||||
pub status: PatchApplyStatus,
|
||||
}
|
||||
|
||||
/// Known change kinds for a patch.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PatchChangeKind {
|
||||
Add,
|
||||
Delete,
|
||||
Update,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default, TS)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum McpToolCallStatus {
|
||||
#[default]
|
||||
InProgress,
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct McpToolCallItem {
|
||||
pub server: String,
|
||||
pub tool: String,
|
||||
pub status: McpToolCallStatus,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct WebSearchItem {
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct ErrorItem {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct TodoItem {
|
||||
pub text: String,
|
||||
pub completed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, TS)]
|
||||
pub struct TodoListItem {
|
||||
pub items: Vec<TodoItem>,
|
||||
}
|
||||
@@ -0,0 +1,368 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::handle_last_message;
|
||||
use crate::exec_events::AssistantMessageItem;
|
||||
use crate::exec_events::CommandExecutionItem;
|
||||
use crate::exec_events::CommandExecutionStatus;
|
||||
use crate::exec_events::ConversationErrorEvent;
|
||||
use crate::exec_events::ConversationEvent;
|
||||
use crate::exec_events::ConversationItem;
|
||||
use crate::exec_events::ConversationItemDetails;
|
||||
use crate::exec_events::FileChangeItem;
|
||||
use crate::exec_events::FileUpdateChange;
|
||||
use crate::exec_events::ItemCompletedEvent;
|
||||
use crate::exec_events::ItemStartedEvent;
|
||||
use crate::exec_events::ItemUpdatedEvent;
|
||||
use crate::exec_events::PatchApplyStatus;
|
||||
use crate::exec_events::PatchChangeKind;
|
||||
use crate::exec_events::ReasoningItem;
|
||||
use crate::exec_events::SessionCreatedEvent;
|
||||
use crate::exec_events::TodoItem;
|
||||
use crate::exec_events::TodoListItem;
|
||||
use crate::exec_events::TurnCompletedEvent;
|
||||
use crate::exec_events::TurnStartedEvent;
|
||||
use crate::exec_events::Usage;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::plan_tool::StepStatus;
|
||||
use codex_core::plan_tool::UpdatePlanArgs;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandBeginEvent;
|
||||
use codex_core::protocol::ExecCommandEndEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::protocol::TaskStartedEvent;
|
||||
use tracing::error;
|
||||
use tracing::warn;
|
||||
|
||||
pub struct ExperimentalEventProcessorWithJsonOutput {
|
||||
last_message_path: Option<PathBuf>,
|
||||
next_event_id: AtomicU64,
|
||||
// Tracks running commands by call_id, including the associated item id.
|
||||
running_commands: HashMap<String, RunningCommand>,
|
||||
running_patch_applies: HashMap<String, PatchApplyBeginEvent>,
|
||||
// Tracks the todo list for the current turn (at most one per turn).
|
||||
running_todo_list: Option<RunningTodoList>,
|
||||
last_total_token_usage: Option<codex_core::protocol::TokenUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RunningCommand {
|
||||
command: String,
|
||||
item_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RunningTodoList {
|
||||
item_id: String,
|
||||
items: Vec<TodoItem>,
|
||||
}
|
||||
|
||||
impl ExperimentalEventProcessorWithJsonOutput {
|
||||
pub fn new(last_message_path: Option<PathBuf>) -> Self {
|
||||
Self {
|
||||
last_message_path,
|
||||
next_event_id: AtomicU64::new(0),
|
||||
running_commands: HashMap::new(),
|
||||
running_patch_applies: HashMap::new(),
|
||||
running_todo_list: None,
|
||||
last_total_token_usage: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn collect_conversation_events(&mut self, event: &Event) -> Vec<ConversationEvent> {
|
||||
match &event.msg {
|
||||
EventMsg::SessionConfigured(ev) => self.handle_session_configured(ev),
|
||||
EventMsg::AgentMessage(ev) => self.handle_agent_message(ev),
|
||||
EventMsg::AgentReasoning(ev) => self.handle_reasoning_event(ev),
|
||||
EventMsg::ExecCommandBegin(ev) => self.handle_exec_command_begin(ev),
|
||||
EventMsg::ExecCommandEnd(ev) => self.handle_exec_command_end(ev),
|
||||
EventMsg::PatchApplyBegin(ev) => self.handle_patch_apply_begin(ev),
|
||||
EventMsg::PatchApplyEnd(ev) => self.handle_patch_apply_end(ev),
|
||||
EventMsg::TokenCount(ev) => {
|
||||
if let Some(info) = &ev.info {
|
||||
self.last_total_token_usage = Some(info.total_token_usage.clone());
|
||||
}
|
||||
Vec::new()
|
||||
}
|
||||
EventMsg::TaskStarted(ev) => self.handle_task_started(ev),
|
||||
EventMsg::TaskComplete(_) => self.handle_task_complete(),
|
||||
EventMsg::Error(ev) => vec![ConversationEvent::Error(ConversationErrorEvent {
|
||||
message: ev.message.clone(),
|
||||
})],
|
||||
EventMsg::StreamError(ev) => vec![ConversationEvent::Error(ConversationErrorEvent {
|
||||
message: ev.message.clone(),
|
||||
})],
|
||||
EventMsg::PlanUpdate(ev) => self.handle_plan_update(ev),
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_next_item_id(&self) -> String {
|
||||
format!(
|
||||
"item_{}",
|
||||
self.next_event_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
|
||||
)
|
||||
}
|
||||
|
||||
fn handle_session_configured(
|
||||
&self,
|
||||
payload: &SessionConfiguredEvent,
|
||||
) -> Vec<ConversationEvent> {
|
||||
vec![ConversationEvent::SessionCreated(SessionCreatedEvent {
|
||||
session_id: payload.session_id.to_string(),
|
||||
})]
|
||||
}
|
||||
|
||||
fn handle_agent_message(&self, payload: &AgentMessageEvent) -> Vec<ConversationEvent> {
|
||||
let item = ConversationItem {
|
||||
id: self.get_next_item_id(),
|
||||
|
||||
details: ConversationItemDetails::AssistantMessage(AssistantMessageItem {
|
||||
text: payload.message.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item,
|
||||
})]
|
||||
}
|
||||
|
||||
fn handle_reasoning_event(&self, ev: &AgentReasoningEvent) -> Vec<ConversationEvent> {
|
||||
let item = ConversationItem {
|
||||
id: self.get_next_item_id(),
|
||||
|
||||
details: ConversationItemDetails::Reasoning(ReasoningItem {
|
||||
text: ev.text.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item,
|
||||
})]
|
||||
}
|
||||
fn handle_exec_command_begin(&mut self, ev: &ExecCommandBeginEvent) -> Vec<ConversationEvent> {
|
||||
let item_id = self.get_next_item_id();
|
||||
|
||||
let command_string = match shlex::try_join(ev.command.iter().map(String::as_str)) {
|
||||
Ok(command_string) => command_string,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
call_id = ev.call_id,
|
||||
"Failed to stringify command: {e:?}; skipping item.started"
|
||||
);
|
||||
ev.command.join(" ")
|
||||
}
|
||||
};
|
||||
|
||||
self.running_commands.insert(
|
||||
ev.call_id.clone(),
|
||||
RunningCommand {
|
||||
command: command_string.clone(),
|
||||
item_id: item_id.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
let item = ConversationItem {
|
||||
id: item_id,
|
||||
details: ConversationItemDetails::CommandExecution(CommandExecutionItem {
|
||||
command: command_string,
|
||||
aggregated_output: String::new(),
|
||||
exit_code: None,
|
||||
status: CommandExecutionStatus::InProgress,
|
||||
}),
|
||||
};
|
||||
|
||||
vec![ConversationEvent::ItemStarted(ItemStartedEvent { item })]
|
||||
}
|
||||
|
||||
fn handle_patch_apply_begin(&mut self, ev: &PatchApplyBeginEvent) -> Vec<ConversationEvent> {
|
||||
self.running_patch_applies
|
||||
.insert(ev.call_id.clone(), ev.clone());
|
||||
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn map_change_kind(&self, kind: &FileChange) -> PatchChangeKind {
|
||||
match kind {
|
||||
FileChange::Add { .. } => PatchChangeKind::Add,
|
||||
FileChange::Delete { .. } => PatchChangeKind::Delete,
|
||||
FileChange::Update { .. } => PatchChangeKind::Update,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_patch_apply_end(&mut self, ev: &PatchApplyEndEvent) -> Vec<ConversationEvent> {
|
||||
if let Some(running_patch_apply) = self.running_patch_applies.remove(&ev.call_id) {
|
||||
let status = if ev.success {
|
||||
PatchApplyStatus::Completed
|
||||
} else {
|
||||
PatchApplyStatus::Failed
|
||||
};
|
||||
let item = ConversationItem {
|
||||
id: self.get_next_item_id(),
|
||||
|
||||
details: ConversationItemDetails::FileChange(FileChangeItem {
|
||||
changes: running_patch_apply
|
||||
.changes
|
||||
.iter()
|
||||
.map(|(path, change)| FileUpdateChange {
|
||||
path: path.to_str().unwrap_or("").to_string(),
|
||||
kind: self.map_change_kind(change),
|
||||
})
|
||||
.collect(),
|
||||
status,
|
||||
}),
|
||||
};
|
||||
|
||||
return vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item,
|
||||
})];
|
||||
}
|
||||
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn handle_exec_command_end(&mut self, ev: &ExecCommandEndEvent) -> Vec<ConversationEvent> {
|
||||
let Some(RunningCommand { command, item_id }) = self.running_commands.remove(&ev.call_id)
|
||||
else {
|
||||
warn!(
|
||||
call_id = ev.call_id,
|
||||
"ExecCommandEnd without matching ExecCommandBegin; skipping item.completed"
|
||||
);
|
||||
return Vec::new();
|
||||
};
|
||||
let status = if ev.exit_code == 0 {
|
||||
CommandExecutionStatus::Completed
|
||||
} else {
|
||||
CommandExecutionStatus::Failed
|
||||
};
|
||||
let item = ConversationItem {
|
||||
id: item_id,
|
||||
|
||||
details: ConversationItemDetails::CommandExecution(CommandExecutionItem {
|
||||
command,
|
||||
aggregated_output: ev.aggregated_output.clone(),
|
||||
exit_code: Some(ev.exit_code),
|
||||
status,
|
||||
}),
|
||||
};
|
||||
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item,
|
||||
})]
|
||||
}
|
||||
|
||||
fn todo_items_from_plan(&self, args: &UpdatePlanArgs) -> Vec<TodoItem> {
|
||||
args.plan
|
||||
.iter()
|
||||
.map(|p| TodoItem {
|
||||
text: p.step.clone(),
|
||||
completed: matches!(p.status, StepStatus::Completed),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn handle_plan_update(&mut self, args: &UpdatePlanArgs) -> Vec<ConversationEvent> {
|
||||
let items = self.todo_items_from_plan(args);
|
||||
|
||||
if let Some(running) = &mut self.running_todo_list {
|
||||
running.items = items.clone();
|
||||
let item = ConversationItem {
|
||||
id: running.item_id.clone(),
|
||||
details: ConversationItemDetails::TodoList(TodoListItem { items }),
|
||||
};
|
||||
return vec![ConversationEvent::ItemUpdated(ItemUpdatedEvent { item })];
|
||||
}
|
||||
|
||||
let item_id = self.get_next_item_id();
|
||||
self.running_todo_list = Some(RunningTodoList {
|
||||
item_id: item_id.clone(),
|
||||
items: items.clone(),
|
||||
});
|
||||
let item = ConversationItem {
|
||||
id: item_id,
|
||||
details: ConversationItemDetails::TodoList(TodoListItem { items }),
|
||||
};
|
||||
vec![ConversationEvent::ItemStarted(ItemStartedEvent { item })]
|
||||
}
|
||||
|
||||
fn handle_task_started(&self, _: &TaskStartedEvent) -> Vec<ConversationEvent> {
|
||||
vec![ConversationEvent::TurnStarted(TurnStartedEvent {})]
|
||||
}
|
||||
|
||||
fn handle_task_complete(&mut self) -> Vec<ConversationEvent> {
|
||||
let usage = if let Some(u) = &self.last_total_token_usage {
|
||||
Usage {
|
||||
input_tokens: u.input_tokens,
|
||||
cached_input_tokens: u.cached_input_tokens,
|
||||
output_tokens: u.output_tokens,
|
||||
}
|
||||
} else {
|
||||
Usage::default()
|
||||
};
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
if let Some(running) = self.running_todo_list.take() {
|
||||
let item = ConversationItem {
|
||||
id: running.item_id,
|
||||
details: ConversationItemDetails::TodoList(TodoListItem {
|
||||
items: running.items,
|
||||
}),
|
||||
};
|
||||
items.push(ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item,
|
||||
}));
|
||||
}
|
||||
|
||||
items.push(ConversationEvent::TurnCompleted(TurnCompletedEvent {
|
||||
usage,
|
||||
}));
|
||||
|
||||
items
|
||||
}
|
||||
}
|
||||
|
||||
impl EventProcessor for ExperimentalEventProcessorWithJsonOutput {
|
||||
fn print_config_summary(&mut self, _: &Config, _: &str, ev: &SessionConfiguredEvent) {
|
||||
self.process_event(Event {
|
||||
id: "".to_string(),
|
||||
msg: EventMsg::SessionConfigured(ev.clone()),
|
||||
});
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) -> CodexStatus {
|
||||
let aggregated = self.collect_conversation_events(&event);
|
||||
for conv_event in aggregated {
|
||||
match serde_json::to_string(&conv_event) {
|
||||
Ok(line) => {
|
||||
println!("{line}");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to serialize event: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let Event { msg, .. } = event;
|
||||
|
||||
if let EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) = msg {
|
||||
if let Some(output_file) = self.last_message_path.as_deref() {
|
||||
handle_last_message(last_agent_message.as_deref(), output_file);
|
||||
}
|
||||
CodexStatus::InitiateShutdown
|
||||
} else {
|
||||
CodexStatus::Running
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
mod cli;
|
||||
mod event_processor;
|
||||
mod event_processor_with_human_output;
|
||||
mod event_processor_with_json_output;
|
||||
pub mod event_processor_with_json_output;
|
||||
pub mod exec_events;
|
||||
pub mod experimental_event_processor_with_json_output;
|
||||
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
@@ -24,7 +26,7 @@ use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_ollama::DEFAULT_OSS_MODEL;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use event_processor_with_human_output::EventProcessorWithHumanOutput;
|
||||
use event_processor_with_json_output::EventProcessorWithJsonOutput;
|
||||
use experimental_event_processor_with_json_output::ExperimentalEventProcessorWithJsonOutput;
|
||||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
@@ -34,6 +36,7 @@ use tracing_subscriber::EnvFilter;
|
||||
use crate::cli::Command as ExecCommand;
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor_with_json_output::EventProcessorWithJsonOutput;
|
||||
use codex_core::find_conversation_path_by_id_str;
|
||||
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
@@ -50,6 +53,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
color,
|
||||
last_message_file,
|
||||
json: json_mode,
|
||||
experimental_json,
|
||||
sandbox_mode: sandbox_mode_cli_arg,
|
||||
prompt,
|
||||
output_schema: output_schema_path,
|
||||
@@ -178,14 +182,22 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
|
||||
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
|
||||
Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
|
||||
} else {
|
||||
Box::new(EventProcessorWithHumanOutput::create_with_ansi(
|
||||
let mut event_processor: Box<dyn EventProcessor> = match (json_mode, experimental_json) {
|
||||
(_, true) => Box::new(ExperimentalEventProcessorWithJsonOutput::new(
|
||||
last_message_file.clone(),
|
||||
)),
|
||||
(true, _) => {
|
||||
eprintln!(
|
||||
"The existing `--json` output format is being deprecated. Please try the new format using `--experimental-json`."
|
||||
);
|
||||
|
||||
Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
|
||||
}
|
||||
_ => Box::new(EventProcessorWithHumanOutput::create_with_ansi(
|
||||
stdout_with_ansi,
|
||||
&config,
|
||||
last_message_file.clone(),
|
||||
))
|
||||
)),
|
||||
};
|
||||
|
||||
if oss {
|
||||
@@ -194,10 +206,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
.map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?;
|
||||
}
|
||||
|
||||
// Print the effective configuration and prompt so users can see what Codex
|
||||
// is using.
|
||||
event_processor.print_config_summary(&config, &prompt);
|
||||
|
||||
let default_cwd = config.cwd.to_path_buf();
|
||||
let default_approval_policy = config.approval_policy;
|
||||
let default_sandbox_policy = config.sandbox_policy.clone();
|
||||
@@ -230,11 +238,19 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
conversation_manager.new_conversation(config).await?
|
||||
conversation_manager
|
||||
.new_conversation(config.clone())
|
||||
.await?
|
||||
}
|
||||
} else {
|
||||
conversation_manager.new_conversation(config).await?
|
||||
conversation_manager
|
||||
.new_conversation(config.clone())
|
||||
.await?
|
||||
};
|
||||
// Print the effective configuration and prompt so users can see what Codex
|
||||
// is using.
|
||||
event_processor.print_config_summary(&config, &prompt, &session_configured);
|
||||
|
||||
info!("Codex initialized with event: {session_configured:?}");
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Event>();
|
||||
@@ -315,7 +331,13 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
info!("Sent prompt with event ID: {initial_prompt_task_id}");
|
||||
|
||||
// Run the loop until the task is complete.
|
||||
// Track whether a fatal error was reported by the server so we can
|
||||
// exit with a non-zero status for automation-friendly signaling.
|
||||
let mut error_seen = false;
|
||||
while let Some(event) = rx.recv().await {
|
||||
if matches!(event.msg, EventMsg::Error(_)) {
|
||||
error_seen = true;
|
||||
}
|
||||
let shutdown: CodexStatus = event_processor.process_event(event);
|
||||
match shutdown {
|
||||
CodexStatus::Running => continue,
|
||||
@@ -327,6 +349,9 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
}
|
||||
}
|
||||
}
|
||||
if error_seen {
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Single integration test binary that aggregates all test modules.
|
||||
// The submodules live in `tests/suite/`.
|
||||
mod suite;
|
||||
|
||||
mod event_processor_with_json_output;
|
||||
|
||||
660
codex-rs/exec/tests/event_processor_with_json_output.rs
Normal file
660
codex-rs/exec/tests/event_processor_with_json_output.rs
Normal file
@@ -0,0 +1,660 @@
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandBeginEvent;
|
||||
use codex_core::protocol::ExecCommandEndEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_exec::exec_events::AssistantMessageItem;
|
||||
use codex_exec::exec_events::CommandExecutionItem;
|
||||
use codex_exec::exec_events::CommandExecutionStatus;
|
||||
use codex_exec::exec_events::ConversationErrorEvent;
|
||||
use codex_exec::exec_events::ConversationEvent;
|
||||
use codex_exec::exec_events::ConversationItem;
|
||||
use codex_exec::exec_events::ConversationItemDetails;
|
||||
use codex_exec::exec_events::ItemCompletedEvent;
|
||||
use codex_exec::exec_events::ItemStartedEvent;
|
||||
use codex_exec::exec_events::ItemUpdatedEvent;
|
||||
use codex_exec::exec_events::PatchApplyStatus;
|
||||
use codex_exec::exec_events::PatchChangeKind;
|
||||
use codex_exec::exec_events::ReasoningItem;
|
||||
use codex_exec::exec_events::SessionCreatedEvent;
|
||||
use codex_exec::exec_events::TodoItem as ExecTodoItem;
|
||||
use codex_exec::exec_events::TodoListItem as ExecTodoListItem;
|
||||
use codex_exec::exec_events::TurnCompletedEvent;
|
||||
use codex_exec::exec_events::TurnStartedEvent;
|
||||
use codex_exec::exec_events::Usage;
|
||||
use codex_exec::experimental_event_processor_with_json_output::ExperimentalEventProcessorWithJsonOutput;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
fn event(id: &str, msg: EventMsg) -> Event {
|
||||
Event {
|
||||
id: id.to_string(),
|
||||
msg,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_configured_produces_session_created_event() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
let session_id = codex_protocol::mcp_protocol::ConversationId::from_string(
|
||||
"67e55044-10b1-426f-9247-bb680e5fe0c8",
|
||||
)
|
||||
.unwrap();
|
||||
let rollout_path = PathBuf::from("/tmp/rollout.json");
|
||||
let ev = event(
|
||||
"e1",
|
||||
EventMsg::SessionConfigured(SessionConfiguredEvent {
|
||||
session_id,
|
||||
model: "codex-mini-latest".to_string(),
|
||||
reasoning_effort: None,
|
||||
history_log_id: 0,
|
||||
history_entry_count: 0,
|
||||
initial_messages: None,
|
||||
rollout_path,
|
||||
}),
|
||||
);
|
||||
let out = ep.collect_conversation_events(&ev);
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::SessionCreated(SessionCreatedEvent {
|
||||
session_id: "67e55044-10b1-426f-9247-bb680e5fe0c8".to_string(),
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn task_started_produces_turn_started_event() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
let out = ep.collect_conversation_events(&event(
|
||||
"t1",
|
||||
EventMsg::TaskStarted(codex_core::protocol::TaskStartedEvent {
|
||||
model_context_window: Some(32_000),
|
||||
}),
|
||||
));
|
||||
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::TurnStarted(TurnStartedEvent {})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_update_emits_todo_list_started_updated_and_completed() {
|
||||
use codex_core::plan_tool::PlanItemArg;
|
||||
use codex_core::plan_tool::StepStatus;
|
||||
use codex_core::plan_tool::UpdatePlanArgs;
|
||||
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// First plan update => item.started (todo_list)
|
||||
let first = event(
|
||||
"p1",
|
||||
EventMsg::PlanUpdate(UpdatePlanArgs {
|
||||
explanation: None,
|
||||
plan: vec![
|
||||
PlanItemArg {
|
||||
step: "step one".to_string(),
|
||||
status: StepStatus::Pending,
|
||||
},
|
||||
PlanItemArg {
|
||||
step: "step two".to_string(),
|
||||
status: StepStatus::InProgress,
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
let out_first = ep.collect_conversation_events(&first);
|
||||
assert_eq!(
|
||||
out_first,
|
||||
vec![ConversationEvent::ItemStarted(ItemStartedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::TodoList(ExecTodoListItem {
|
||||
items: vec![
|
||||
ExecTodoItem {
|
||||
text: "step one".to_string(),
|
||||
completed: false
|
||||
},
|
||||
ExecTodoItem {
|
||||
text: "step two".to_string(),
|
||||
completed: false
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
|
||||
// Second plan update in same turn => item.updated (same id)
|
||||
let second = event(
|
||||
"p2",
|
||||
EventMsg::PlanUpdate(UpdatePlanArgs {
|
||||
explanation: None,
|
||||
plan: vec![
|
||||
PlanItemArg {
|
||||
step: "step one".to_string(),
|
||||
status: StepStatus::Completed,
|
||||
},
|
||||
PlanItemArg {
|
||||
step: "step two".to_string(),
|
||||
status: StepStatus::InProgress,
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
let out_second = ep.collect_conversation_events(&second);
|
||||
assert_eq!(
|
||||
out_second,
|
||||
vec![ConversationEvent::ItemUpdated(ItemUpdatedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::TodoList(ExecTodoListItem {
|
||||
items: vec![
|
||||
ExecTodoItem {
|
||||
text: "step one".to_string(),
|
||||
completed: true
|
||||
},
|
||||
ExecTodoItem {
|
||||
text: "step two".to_string(),
|
||||
completed: false
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
|
||||
// Task completes => item.completed (same id, latest state)
|
||||
let complete = event(
|
||||
"p3",
|
||||
EventMsg::TaskComplete(codex_core::protocol::TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
);
|
||||
let out_complete = ep.collect_conversation_events(&complete);
|
||||
assert_eq!(
|
||||
out_complete,
|
||||
vec![
|
||||
ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::TodoList(ExecTodoListItem {
|
||||
items: vec![
|
||||
ExecTodoItem {
|
||||
text: "step one".to_string(),
|
||||
completed: true
|
||||
},
|
||||
ExecTodoItem {
|
||||
text: "step two".to_string(),
|
||||
completed: false
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
}),
|
||||
ConversationEvent::TurnCompleted(TurnCompletedEvent {
|
||||
usage: Usage::default(),
|
||||
}),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_update_after_complete_starts_new_todo_list_with_new_id() {
|
||||
use codex_core::plan_tool::PlanItemArg;
|
||||
use codex_core::plan_tool::StepStatus;
|
||||
use codex_core::plan_tool::UpdatePlanArgs;
|
||||
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// First turn: start + complete
|
||||
let start = event(
|
||||
"t1",
|
||||
EventMsg::PlanUpdate(UpdatePlanArgs {
|
||||
explanation: None,
|
||||
plan: vec![PlanItemArg {
|
||||
step: "only".to_string(),
|
||||
status: StepStatus::Pending,
|
||||
}],
|
||||
}),
|
||||
);
|
||||
let _ = ep.collect_conversation_events(&start);
|
||||
let complete = event(
|
||||
"t2",
|
||||
EventMsg::TaskComplete(codex_core::protocol::TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
);
|
||||
let _ = ep.collect_conversation_events(&complete);
|
||||
|
||||
// Second turn: a new todo list should have a new id
|
||||
let start_again = event(
|
||||
"t3",
|
||||
EventMsg::PlanUpdate(UpdatePlanArgs {
|
||||
explanation: None,
|
||||
plan: vec![PlanItemArg {
|
||||
step: "again".to_string(),
|
||||
status: StepStatus::Pending,
|
||||
}],
|
||||
}),
|
||||
);
|
||||
let out = ep.collect_conversation_events(&start_again);
|
||||
|
||||
match &out[0] {
|
||||
ConversationEvent::ItemStarted(ItemStartedEvent { item }) => {
|
||||
assert_eq!(&item.id, "item_1");
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_reasoning_produces_item_completed_reasoning() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
let ev = event(
|
||||
"e1",
|
||||
EventMsg::AgentReasoning(AgentReasoningEvent {
|
||||
text: "thinking...".to_string(),
|
||||
}),
|
||||
);
|
||||
let out = ep.collect_conversation_events(&ev);
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::Reasoning(ReasoningItem {
|
||||
text: "thinking...".to_string(),
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_message_produces_item_completed_assistant_message() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
let ev = event(
|
||||
"e1",
|
||||
EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "hello".to_string(),
|
||||
}),
|
||||
);
|
||||
let out = ep.collect_conversation_events(&ev);
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::AssistantMessage(AssistantMessageItem {
|
||||
text: "hello".to_string(),
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_event_produces_error() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
let out = ep.collect_conversation_events(&event(
|
||||
"e1",
|
||||
EventMsg::Error(codex_core::protocol::ErrorEvent {
|
||||
message: "boom".to_string(),
|
||||
}),
|
||||
));
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::Error(ConversationErrorEvent {
|
||||
message: "boom".to_string(),
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_error_event_produces_error() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
let out = ep.collect_conversation_events(&event(
|
||||
"e1",
|
||||
EventMsg::StreamError(codex_core::protocol::StreamErrorEvent {
|
||||
message: "retrying".to_string(),
|
||||
}),
|
||||
));
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::Error(ConversationErrorEvent {
|
||||
message: "retrying".to_string(),
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exec_command_end_success_produces_completed_command_item() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// Begin -> no output
|
||||
let begin = event(
|
||||
"c1",
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id: "1".to_string(),
|
||||
command: vec!["bash".to_string(), "-lc".to_string(), "echo hi".to_string()],
|
||||
cwd: std::env::current_dir().unwrap(),
|
||||
parsed_cmd: Vec::new(),
|
||||
}),
|
||||
);
|
||||
let out_begin = ep.collect_conversation_events(&begin);
|
||||
assert_eq!(
|
||||
out_begin,
|
||||
vec![ConversationEvent::ItemStarted(ItemStartedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::CommandExecution(CommandExecutionItem {
|
||||
command: "bash -lc 'echo hi'".to_string(),
|
||||
aggregated_output: String::new(),
|
||||
exit_code: None,
|
||||
status: CommandExecutionStatus::InProgress,
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
|
||||
// End (success) -> item.completed (item_0)
|
||||
let end_ok = event(
|
||||
"c2",
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: "1".to_string(),
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
aggregated_output: "hi\n".to_string(),
|
||||
exit_code: 0,
|
||||
duration: Duration::from_millis(5),
|
||||
formatted_output: String::new(),
|
||||
}),
|
||||
);
|
||||
let out_ok = ep.collect_conversation_events(&end_ok);
|
||||
assert_eq!(
|
||||
out_ok,
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::CommandExecution(CommandExecutionItem {
|
||||
command: "bash -lc 'echo hi'".to_string(),
|
||||
aggregated_output: "hi\n".to_string(),
|
||||
exit_code: Some(0),
|
||||
status: CommandExecutionStatus::Completed,
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exec_command_end_failure_produces_failed_command_item() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// Begin -> no output
|
||||
let begin = event(
|
||||
"c1",
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id: "2".to_string(),
|
||||
command: vec!["sh".to_string(), "-c".to_string(), "exit 1".to_string()],
|
||||
cwd: std::env::current_dir().unwrap(),
|
||||
parsed_cmd: Vec::new(),
|
||||
}),
|
||||
);
|
||||
assert_eq!(
|
||||
ep.collect_conversation_events(&begin),
|
||||
vec![ConversationEvent::ItemStarted(ItemStartedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::CommandExecution(CommandExecutionItem {
|
||||
command: "sh -c 'exit 1'".to_string(),
|
||||
aggregated_output: String::new(),
|
||||
exit_code: None,
|
||||
status: CommandExecutionStatus::InProgress,
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
|
||||
// End (failure) -> item.completed (item_0)
|
||||
let end_fail = event(
|
||||
"c2",
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: "2".to_string(),
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
aggregated_output: String::new(),
|
||||
exit_code: 1,
|
||||
duration: Duration::from_millis(2),
|
||||
formatted_output: String::new(),
|
||||
}),
|
||||
);
|
||||
let out_fail = ep.collect_conversation_events(&end_fail);
|
||||
assert_eq!(
|
||||
out_fail,
|
||||
vec![ConversationEvent::ItemCompleted(ItemCompletedEvent {
|
||||
item: ConversationItem {
|
||||
id: "item_0".to_string(),
|
||||
details: ConversationItemDetails::CommandExecution(CommandExecutionItem {
|
||||
command: "sh -c 'exit 1'".to_string(),
|
||||
aggregated_output: String::new(),
|
||||
exit_code: Some(1),
|
||||
status: CommandExecutionStatus::Failed,
|
||||
}),
|
||||
},
|
||||
})]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exec_command_end_without_begin_is_ignored() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// End event arrives without a prior Begin; should produce no conversation events.
|
||||
let end_only = event(
|
||||
"c1",
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: "no-begin".to_string(),
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
aggregated_output: String::new(),
|
||||
exit_code: 0,
|
||||
duration: Duration::from_millis(1),
|
||||
formatted_output: String::new(),
|
||||
}),
|
||||
);
|
||||
let out = ep.collect_conversation_events(&end_only);
|
||||
assert!(out.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn patch_apply_success_produces_item_completed_patchapply() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// Prepare a patch with multiple kinds of changes
|
||||
let mut changes = std::collections::HashMap::new();
|
||||
changes.insert(
|
||||
PathBuf::from("a/added.txt"),
|
||||
FileChange::Add {
|
||||
content: "+hello".to_string(),
|
||||
},
|
||||
);
|
||||
changes.insert(
|
||||
PathBuf::from("b/deleted.txt"),
|
||||
FileChange::Delete {
|
||||
content: "-goodbye".to_string(),
|
||||
},
|
||||
);
|
||||
changes.insert(
|
||||
PathBuf::from("c/modified.txt"),
|
||||
FileChange::Update {
|
||||
unified_diff: "--- c/modified.txt\n+++ c/modified.txt\n@@\n-old\n+new\n".to_string(),
|
||||
move_path: Some(PathBuf::from("c/renamed.txt")),
|
||||
},
|
||||
);
|
||||
|
||||
// Begin -> no output
|
||||
let begin = event(
|
||||
"p1",
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id: "call-1".to_string(),
|
||||
auto_approved: true,
|
||||
changes: changes.clone(),
|
||||
}),
|
||||
);
|
||||
let out_begin = ep.collect_conversation_events(&begin);
|
||||
assert!(out_begin.is_empty());
|
||||
|
||||
// End (success) -> item.completed (item_0)
|
||||
let end = event(
|
||||
"p2",
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id: "call-1".to_string(),
|
||||
stdout: "applied 3 changes".to_string(),
|
||||
stderr: String::new(),
|
||||
success: true,
|
||||
}),
|
||||
);
|
||||
let out_end = ep.collect_conversation_events(&end);
|
||||
assert_eq!(out_end.len(), 1);
|
||||
|
||||
// Validate structure without relying on HashMap iteration order
|
||||
match &out_end[0] {
|
||||
ConversationEvent::ItemCompleted(ItemCompletedEvent { item }) => {
|
||||
assert_eq!(&item.id, "item_0");
|
||||
match &item.details {
|
||||
ConversationItemDetails::FileChange(file_update) => {
|
||||
assert_eq!(file_update.status, PatchApplyStatus::Completed);
|
||||
|
||||
let mut actual: Vec<(String, PatchChangeKind)> = file_update
|
||||
.changes
|
||||
.iter()
|
||||
.map(|c| (c.path.clone(), c.kind.clone()))
|
||||
.collect();
|
||||
actual.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
let mut expected = vec![
|
||||
("a/added.txt".to_string(), PatchChangeKind::Add),
|
||||
("b/deleted.txt".to_string(), PatchChangeKind::Delete),
|
||||
("c/modified.txt".to_string(), PatchChangeKind::Update),
|
||||
];
|
||||
expected.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
other => panic!("unexpected details: {other:?}"),
|
||||
}
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn patch_apply_failure_produces_item_completed_patchapply_failed() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
let mut changes = std::collections::HashMap::new();
|
||||
changes.insert(
|
||||
PathBuf::from("file.txt"),
|
||||
FileChange::Update {
|
||||
unified_diff: "--- file.txt\n+++ file.txt\n@@\n-old\n+new\n".to_string(),
|
||||
move_path: None,
|
||||
},
|
||||
);
|
||||
|
||||
// Begin -> no output
|
||||
let begin = event(
|
||||
"p1",
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id: "call-2".to_string(),
|
||||
auto_approved: false,
|
||||
changes: changes.clone(),
|
||||
}),
|
||||
);
|
||||
assert!(ep.collect_conversation_events(&begin).is_empty());
|
||||
|
||||
// End (failure) -> item.completed (item_0) with Failed status
|
||||
let end = event(
|
||||
"p2",
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id: "call-2".to_string(),
|
||||
stdout: String::new(),
|
||||
stderr: "failed to apply".to_string(),
|
||||
success: false,
|
||||
}),
|
||||
);
|
||||
let out_end = ep.collect_conversation_events(&end);
|
||||
assert_eq!(out_end.len(), 1);
|
||||
|
||||
match &out_end[0] {
|
||||
ConversationEvent::ItemCompleted(ItemCompletedEvent { item }) => {
|
||||
assert_eq!(&item.id, "item_0");
|
||||
match &item.details {
|
||||
ConversationItemDetails::FileChange(file_update) => {
|
||||
assert_eq!(file_update.status, PatchApplyStatus::Failed);
|
||||
assert_eq!(file_update.changes.len(), 1);
|
||||
assert_eq!(file_update.changes[0].path, "file.txt".to_string());
|
||||
assert_eq!(file_update.changes[0].kind, PatchChangeKind::Update);
|
||||
}
|
||||
other => panic!("unexpected details: {other:?}"),
|
||||
}
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn task_complete_produces_turn_completed_with_usage() {
|
||||
let mut ep = ExperimentalEventProcessorWithJsonOutput::new(None);
|
||||
|
||||
// First, feed a TokenCount event with known totals.
|
||||
let usage = codex_core::protocol::TokenUsage {
|
||||
input_tokens: 1200,
|
||||
cached_input_tokens: 200,
|
||||
output_tokens: 345,
|
||||
reasoning_output_tokens: 0,
|
||||
total_tokens: 0,
|
||||
};
|
||||
let info = codex_core::protocol::TokenUsageInfo {
|
||||
total_token_usage: usage.clone(),
|
||||
last_token_usage: usage,
|
||||
model_context_window: None,
|
||||
};
|
||||
let token_count_event = event(
|
||||
"e1",
|
||||
EventMsg::TokenCount(codex_core::protocol::TokenCountEvent {
|
||||
info: Some(info),
|
||||
rate_limits: None,
|
||||
}),
|
||||
);
|
||||
assert!(
|
||||
ep.collect_conversation_events(&token_count_event)
|
||||
.is_empty()
|
||||
);
|
||||
|
||||
// Then TaskComplete should produce turn.completed with the captured usage.
|
||||
let complete_event = event(
|
||||
"e2",
|
||||
EventMsg::TaskComplete(codex_core::protocol::TaskCompleteEvent {
|
||||
last_agent_message: Some("done".to_string()),
|
||||
}),
|
||||
);
|
||||
let out = ep.collect_conversation_events(&complete_event);
|
||||
assert_eq!(
|
||||
out,
|
||||
vec![ConversationEvent::TurnCompleted(TurnCompletedEvent {
|
||||
usage: Usage {
|
||||
input_tokens: 1200,
|
||||
cached_input_tokens: 200,
|
||||
output_tokens: 345,
|
||||
},
|
||||
})]
|
||||
);
|
||||
}
|
||||
@@ -6,7 +6,9 @@ use codex_core::CODEX_APPLY_PATCH_ARG1;
|
||||
use core_test_support::responses::ev_apply_patch_custom_tool_call;
|
||||
use core_test_support::responses::ev_apply_patch_function_call;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
use tempfile::tempdir;
|
||||
@@ -47,13 +49,13 @@ fn test_standalone_exec_cli_can_use_apply_patch() -> anyhow::Result<()> {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_apply_patch_tool() -> anyhow::Result<()> {
|
||||
use crate::suite::common::run_e2e_exec_test;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
|
||||
non_sandbox_test!(result);
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let tmp_cwd = tempdir().expect("failed to create temp dir");
|
||||
let tmp_path = tmp_cwd.path().to_path_buf();
|
||||
let test = test_codex_exec();
|
||||
let tmp_path = test.cwd_path().to_path_buf();
|
||||
let add_patch = r#"*** Begin Patch
|
||||
*** Add File: test.md
|
||||
+Hello world
|
||||
@@ -75,7 +77,16 @@ async fn test_apply_patch_tool() -> anyhow::Result<()> {
|
||||
]),
|
||||
sse(vec![ev_completed("request_2")]),
|
||||
];
|
||||
run_e2e_exec_test(tmp_cwd.path(), response_streams).await;
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_sequence(&server, response_streams).await;
|
||||
|
||||
test.cmd_with_server(&server)
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-s")
|
||||
.arg("danger-full-access")
|
||||
.arg("foo")
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
let final_path = tmp_path.join("test.md");
|
||||
let contents = std::fs::read_to_string(&final_path)
|
||||
@@ -87,12 +98,12 @@ async fn test_apply_patch_tool() -> anyhow::Result<()> {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_apply_patch_freeform_tool() -> anyhow::Result<()> {
|
||||
use crate::suite::common::run_e2e_exec_test;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
|
||||
non_sandbox_test!(result);
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let tmp_cwd = tempdir().expect("failed to create temp dir");
|
||||
let test = test_codex_exec();
|
||||
let freeform_add_patch = r#"*** Begin Patch
|
||||
*** Add File: app.py
|
||||
+class BaseClass:
|
||||
@@ -117,10 +128,19 @@ async fn test_apply_patch_freeform_tool() -> anyhow::Result<()> {
|
||||
]),
|
||||
sse(vec![ev_completed("request_2")]),
|
||||
];
|
||||
run_e2e_exec_test(tmp_cwd.path(), response_streams).await;
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_sequence(&server, response_streams).await;
|
||||
|
||||
test.cmd_with_server(&server)
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-s")
|
||||
.arg("danger-full-access")
|
||||
.arg("foo")
|
||||
.assert()
|
||||
.success();
|
||||
|
||||
// Verify final file contents
|
||||
let final_path = tmp_cwd.path().join("app.py");
|
||||
let final_path = test.cwd_path().join("app.py");
|
||||
let contents = std::fs::read_to_string(&final_path)
|
||||
.unwrap_or_else(|e| panic!("failed reading {}: {e}", final_path.display()));
|
||||
assert_eq!(
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
// this file is only used for e2e tests which are currently disabled on windows
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
use wiremock::Respond;
|
||||
|
||||
struct SeqResponder {
|
||||
num_calls: AtomicUsize,
|
||||
responses: Vec<String>,
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> wiremock::ResponseTemplate {
|
||||
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
|
||||
match self.responses.get(call_num) {
|
||||
Some(body) => wiremock::ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_string(body.clone()),
|
||||
None => panic!("no response for {call_num}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to run an E2E test of a codex-exec call. Starts a wiremock
|
||||
/// server, and returns the response_streams in order for each api call. Runs
|
||||
/// the codex-exec command with the wiremock server as the model server.
|
||||
pub(crate) async fn run_e2e_exec_test(cwd: &Path, response_streams: Vec<String>) {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let num_calls = response_streams.len();
|
||||
let seq_responder = SeqResponder {
|
||||
num_calls: AtomicUsize::new(0),
|
||||
responses: response_streams,
|
||||
};
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(seq_responder)
|
||||
.expect(num_calls as u64)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let cwd = cwd.to_path_buf();
|
||||
let uri = server.uri();
|
||||
Command::cargo_bin("codex-exec")
|
||||
.context("should find binary for codex-exec")
|
||||
.expect("should find binary for codex-exec")
|
||||
.current_dir(cwd.clone())
|
||||
.env("CODEX_HOME", cwd)
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("OPENAI_BASE_URL", format!("{uri}/v1"))
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-s")
|
||||
.arg("danger-full-access")
|
||||
.arg("foo")
|
||||
.assert()
|
||||
.success();
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
// Aggregates all former standalone integration tests as modules.
|
||||
mod apply_patch;
|
||||
mod common;
|
||||
mod output_schema;
|
||||
mod resume;
|
||||
mod sandbox;
|
||||
mod server_error_exit;
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use assert_cmd::prelude::*;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
use serde_json::Value;
|
||||
use std::process::Command;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {
|
||||
let home = TempDir::new()?;
|
||||
let workspace = TempDir::new()?;
|
||||
let test = test_codex_exec();
|
||||
|
||||
let schema_contents = serde_json::json!({
|
||||
"type": "object",
|
||||
@@ -21,7 +18,7 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {
|
||||
"required": ["answer"],
|
||||
"additionalProperties": false
|
||||
});
|
||||
let schema_path = workspace.path().join("schema.json");
|
||||
let schema_path = test.cwd_path().join("schema.json");
|
||||
std::fs::write(&schema_path, serde_json::to_vec_pretty(&schema_contents)?)?;
|
||||
let expected_schema: Value = schema_contents;
|
||||
|
||||
@@ -36,14 +33,11 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {
|
||||
]);
|
||||
responses::mount_sse_once(&server, any(), body).await;
|
||||
|
||||
Command::cargo_bin("codex-exec")?
|
||||
.current_dir(workspace.path())
|
||||
.env("CODEX_HOME", home.path())
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("OPENAI_BASE_URL", format!("{}/v1", server.uri()))
|
||||
test.cmd_with_server(&server)
|
||||
.arg("--skip-git-repo-check")
|
||||
// keep using -C in the test to exercise the flag as well
|
||||
.arg("-C")
|
||||
.arg(workspace.path())
|
||||
.arg(test.cwd_path())
|
||||
.arg("--output-schema")
|
||||
.arg(&schema_path)
|
||||
.arg("-m")
|
||||
|
||||
@@ -56,6 +56,7 @@ async fn spawn_command_under_sandbox(
|
||||
|
||||
#[tokio::test]
|
||||
async fn python_multiprocessing_lock_works_under_sandbox() {
|
||||
core_test_support::skip_if_sandbox!();
|
||||
#[cfg(target_os = "macos")]
|
||||
let writable_roots = Vec::<PathBuf>::new();
|
||||
|
||||
@@ -110,6 +111,7 @@ if __name__ == '__main__':
|
||||
|
||||
#[tokio::test]
|
||||
async fn sandbox_distinguishes_command_and_policy_cwds() {
|
||||
core_test_support::skip_if_sandbox!();
|
||||
let temp = tempfile::tempdir().expect("should be able to create temp dir");
|
||||
let sandbox_root = temp.path().join("sandbox");
|
||||
let command_root = temp.path().join("command");
|
||||
|
||||
34
codex-rs/exec/tests/suite/server_error_exit.rs
Normal file
34
codex-rs/exec/tests/suite/server_error_exit.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use core_test_support::responses;
|
||||
use core_test_support::test_codex_exec::test_codex_exec;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
/// Verify that when the server reports an error, `codex-exec` exits with a
|
||||
/// non-zero status code so automation can detect failures.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exits_non_zero_when_server_reports_error() -> anyhow::Result<()> {
|
||||
let test = test_codex_exec();
|
||||
|
||||
// Mock a simple Responses API SSE stream that immediately reports a
|
||||
// `response.failed` event with an error message.
|
||||
let server = responses::start_mock_server().await;
|
||||
let body = responses::sse(vec![serde_json::json!({
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": "resp_err_1",
|
||||
"error": {"code": "rate_limit_exceeded", "message": "synthetic server error"}
|
||||
}
|
||||
})]);
|
||||
responses::mount_sse_once(&server, any(), body).await;
|
||||
|
||||
test.cmd_with_server(&server)
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("tell me something")
|
||||
.arg("--experimental-json")
|
||||
.assert()
|
||||
.code(1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -9,7 +9,7 @@ use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use codex_login::ServerOptions;
|
||||
use codex_login::run_login_server;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use tempfile::tempdir;
|
||||
|
||||
// See spawn.rs for details
|
||||
@@ -78,7 +78,7 @@ fn start_mock_issuer() -> (SocketAddr, thread::JoinHandle<()>) {
|
||||
|
||||
#[tokio::test]
|
||||
async fn end_to_end_login_flow_persists_auth_json() -> Result<()> {
|
||||
non_sandbox_test!(result);
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let (issuer_addr, issuer_handle) = start_mock_issuer();
|
||||
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
|
||||
@@ -147,7 +147,7 @@ async fn end_to_end_login_flow_persists_auth_json() -> Result<()> {
|
||||
|
||||
#[tokio::test]
|
||||
async fn creates_missing_codex_home_dir() -> Result<()> {
|
||||
non_sandbox_test!(result);
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let (issuer_addr, _issuer_handle) = start_mock_issuer();
|
||||
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
|
||||
@@ -187,7 +187,7 @@ async fn creates_missing_codex_home_dir() -> Result<()> {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn cancels_previous_login_server_when_port_is_in_use() -> Result<()> {
|
||||
non_sandbox_test!(result);
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let (issuer_addr, _issuer_handle) = start_mock_issuer();
|
||||
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
|
||||
|
||||
@@ -70,11 +70,8 @@ async fn main() -> Result<()> {
|
||||
},
|
||||
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
let initialize_notification_params = None;
|
||||
let timeout = Some(Duration::from_secs(10));
|
||||
let response = client
|
||||
.initialize(params, initialize_notification_params, timeout)
|
||||
.await?;
|
||||
let response = client.initialize(params, timeout).await?;
|
||||
eprintln!("initialize response: {response:?}");
|
||||
|
||||
// Issue `tools/list` request (no params).
|
||||
|
||||
@@ -315,13 +315,12 @@ impl McpClient {
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
initialize_params: InitializeRequestParams,
|
||||
initialize_notification_params: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::InitializeResult> {
|
||||
let response = self
|
||||
.send_request::<InitializeRequest>(initialize_params, timeout)
|
||||
.await?;
|
||||
self.send_notification::<InitializedNotification>(initialize_notification_params)
|
||||
self.send_notification::<InitializedNotification>(None)
|
||||
.await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::MockServer;
|
||||
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_apply_patch_sse_response;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
@@ -308,7 +308,7 @@ async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_codex_tool_passes_base_instructions() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
// Apparently `#[tokio::test]` must return `()`, so we create a helper
|
||||
// function that returns `Result` so we can use `?` in favor of `unwrap`.
|
||||
|
||||
@@ -11,7 +11,7 @@ use codex_protocol::mcp_protocol::NewConversationParams;
|
||||
use codex_protocol::mcp_protocol::NewConversationResponse;
|
||||
use codex_protocol::mcp_protocol::SendUserMessageParams;
|
||||
use codex_protocol::mcp_protocol::SendUserMessageResponse;
|
||||
use core_test_support::non_sandbox_test;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use tempfile::TempDir;
|
||||
@@ -26,7 +26,7 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shell_command_interruption() {
|
||||
non_sandbox_test!();
|
||||
skip_if_no_network!();
|
||||
|
||||
if let Err(err) = shell_command_interruption().await {
|
||||
panic!("failure: {err}");
|
||||
|
||||
@@ -597,16 +597,18 @@ pub struct TokenCountEvent {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct RateLimitSnapshot {
|
||||
/// Percentage (0-100) of the primary window that has been consumed.
|
||||
pub primary_used_percent: f64,
|
||||
/// Percentage (0-100) of the secondary window that has been consumed.
|
||||
pub secondary_used_percent: f64,
|
||||
/// Size of the primary window relative to secondary (0-100).
|
||||
pub primary_to_secondary_ratio_percent: f64,
|
||||
/// Rolling window duration for the primary limit, in minutes.
|
||||
pub primary_window_minutes: u64,
|
||||
/// Rolling window duration for the secondary limit, in minutes.
|
||||
pub secondary_window_minutes: u64,
|
||||
pub primary: Option<RateLimitWindow>,
|
||||
pub secondary: Option<RateLimitWindow>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
|
||||
pub struct RateLimitWindow {
|
||||
/// Percentage (0-100) of the window that has been consumed.
|
||||
pub used_percent: f64,
|
||||
/// Rolling window duration, in minutes.
|
||||
pub window_minutes: Option<u64>,
|
||||
/// Seconds until the window resets.
|
||||
pub resets_in_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
// Includes prompts, tools and space to call compact.
|
||||
|
||||
27
codex-rs/responses-api-proxy/Cargo.toml
Normal file
27
codex-rs/responses-api-proxy/Cargo.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-responses-api-proxy"
|
||||
version = { workspace = true }
|
||||
|
||||
[lib]
|
||||
name = "codex_responses_api_proxy"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "responses-api-proxy"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-arg0 = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["blocking", "json", "rustls-tls"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tiny_http = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
53
codex-rs/responses-api-proxy/README.md
Normal file
53
codex-rs/responses-api-proxy/README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# codex-responses-api-proxy
|
||||
|
||||
A strict HTTP proxy that only forwards `POST` requests to `/v1/responses` to the OpenAI API (`https://api.openai.com`), injecting the `Authorization: Bearer $OPENAI_API_KEY` header. Everything else is rejected with `403 Forbidden`.
|
||||
|
||||
## Expected Usage
|
||||
|
||||
**IMPORTANT:** This is designed to be used with `CODEX_SECURE_MODE=1` so that an unprivileged user cannot inspect or tamper with this process. Though if `--http-shutdown` is specified, an unprivileged user _can_ shutdown the server.
|
||||
|
||||
A privileged user (i.e., `root` or a user with `sudo`) who has access to `OPENAI_API_KEY` would run the following to start the server:
|
||||
|
||||
```shell
|
||||
printenv OPENAI_API_KEY | CODEX_SECURE_MODE=1 codex responses-api-proxy --http-shutdown --server-info /tmp/server-info.json
|
||||
```
|
||||
|
||||
A non-privileged user would then run Codex as follows, specifying the `model_provider` dynamically:
|
||||
|
||||
```shell
|
||||
PROXY_PORT=$(jq .port /tmp/server-info.json)
|
||||
PROXY_BASE_URL="http://127.0.0.1:${PROXY_PORT}"
|
||||
codex exec -c "model_providers.openai-proxy={ name = 'OpenAI Proxy', base_url = '${PROXY_BASE_URL}/v1', wire_api='responses' }" \
|
||||
-c model_provider="openai-proxy" \
|
||||
'Your prompt here'
|
||||
```
|
||||
|
||||
When the unprivileged user was finished, they could shutdown the server using `curl` (since `kill -9` is not an option):
|
||||
|
||||
```shell
|
||||
curl --fail --silent --show-error "${PROXY_BASE_URL}/shutdown"
|
||||
```
|
||||
|
||||
## Behavior
|
||||
|
||||
- Reads the API key from `stdin`. All callers should pipe the key in (for example, `printenv OPENAI_API_KEY | codex responses-api-proxy`).
|
||||
- Formats the header value as `Bearer <key>` and attempts to `mlock(2)` the memory holding that header so it is not swapped to disk.
|
||||
- Listens on the provided port or an ephemeral port if `--port` is not specified.
|
||||
- Accepts exactly `POST /v1/responses` (no query string). The request body is forwarded to `https://api.openai.com/v1/responses` with `Authorization: Bearer <key>` set. All original request headers (except any incoming `Authorization`) are forwarded upstream. For other requests, it responds with `403`.
|
||||
- Optionally writes a single-line JSON file with server info, currently `{ "port": <u16> }`.
|
||||
- Optional `--http-shutdown` enables `GET /shutdown` to terminate the process with exit code 0. This allows one user (e.g., root) to start the proxy and another unprivileged user on the host to shut it down.
|
||||
|
||||
## CLI
|
||||
|
||||
```
|
||||
responses-api-proxy [--port <PORT>] [--server-info <FILE>] [--http-shutdown]
|
||||
```
|
||||
|
||||
- `--port <PORT>`: Port to bind on `127.0.0.1`. If omitted, an ephemeral port is chosen.
|
||||
- `--server-info <FILE>`: If set, the proxy writes a single line of JSON with `{ "port": <PORT> }` once listening.
|
||||
- `--http-shutdown`: If set, enables `GET /shutdown` to exit the process with code `0`.
|
||||
|
||||
## Notes
|
||||
|
||||
- Only `POST /v1/responses` is permitted. No query strings are allowed.
|
||||
- All request headers are forwarded to the upstream call (aside from overriding `Authorization`). Response status and content-type are mirrored from upstream.
|
||||
202
codex-rs/responses-api-proxy/src/lib.rs
Normal file
202
codex-rs/responses-api-proxy/src/lib.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpListener;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use clap::Parser;
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::HOST;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
use serde::Serialize;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Method;
|
||||
use tiny_http::Request;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tiny_http::StatusCode;
|
||||
|
||||
mod read_api_key;
|
||||
use read_api_key::read_auth_header_from_stdin;
|
||||
|
||||
/// CLI arguments for the proxy.
|
||||
#[derive(Debug, Clone, Parser)]
|
||||
#[command(name = "responses-api-proxy", about = "Minimal OpenAI responses proxy")]
|
||||
pub struct Args {
|
||||
/// Port to listen on. If not set, an ephemeral port is used.
|
||||
#[arg(long)]
|
||||
pub port: Option<u16>,
|
||||
|
||||
/// Path to a JSON file to write startup info (single line). Includes {"port": <u16>}.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
pub server_info: Option<PathBuf>,
|
||||
|
||||
/// Enable HTTP shutdown endpoint at GET /shutdown
|
||||
#[arg(long)]
|
||||
pub http_shutdown: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ServerInfo {
|
||||
port: u16,
|
||||
}
|
||||
|
||||
/// Entry point for the library main, for parity with other crates.
|
||||
pub fn run_main(args: Args) -> Result<()> {
|
||||
let auth_header = read_auth_header_from_stdin()?;
|
||||
|
||||
let (listener, bound_addr) = bind_listener(args.port)?;
|
||||
if let Some(path) = args.server_info.as_ref() {
|
||||
write_server_info(path, bound_addr.port())?;
|
||||
}
|
||||
let server = Server::from_listener(listener, None)
|
||||
.map_err(|err| anyhow!("creating HTTP server: {err}"))?;
|
||||
let client = Arc::new(
|
||||
Client::builder()
|
||||
.build()
|
||||
.context("building reqwest client")?,
|
||||
);
|
||||
|
||||
eprintln!("responses-api-proxy listening on {bound_addr}");
|
||||
|
||||
let http_shutdown = args.http_shutdown;
|
||||
for request in server.incoming_requests() {
|
||||
let client = client.clone();
|
||||
std::thread::spawn(move || {
|
||||
if http_shutdown && request.method() == &Method::Get && request.url() == "/shutdown" {
|
||||
let _ = request.respond(Response::new_empty(StatusCode(200)));
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
if let Err(e) = forward_request(&client, auth_header, request) {
|
||||
eprintln!("forwarding error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Err(anyhow!("server stopped unexpectedly"))
|
||||
}
|
||||
|
||||
fn bind_listener(port: Option<u16>) -> Result<(TcpListener, SocketAddr)> {
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port.unwrap_or(0)));
|
||||
let listener = TcpListener::bind(addr).with_context(|| format!("failed to bind {addr}"))?;
|
||||
let bound = listener.local_addr().context("failed to read local_addr")?;
|
||||
Ok((listener, bound))
|
||||
}
|
||||
|
||||
fn write_server_info(path: &Path, port: u16) -> Result<()> {
|
||||
if let Some(parent) = path.parent()
|
||||
&& !parent.as_os_str().is_empty()
|
||||
{
|
||||
let parent_display = parent.display();
|
||||
fs::create_dir_all(parent).with_context(|| format!("create_dir_all {parent_display}"))?;
|
||||
}
|
||||
let info = ServerInfo { port };
|
||||
let data = serde_json::to_vec(&info).context("serialize startup info")?;
|
||||
let p = path.display();
|
||||
let mut f = File::create(path).with_context(|| format!("create {p}"))?;
|
||||
f.write_all(&data).with_context(|| format!("write {p}"))?;
|
||||
f.write_all(b"\n").with_context(|| format!("newline {p}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn forward_request(client: &Client, auth_header: &'static str, mut req: Request) -> Result<()> {
|
||||
// Only allow POST /v1/responses exactly, no query string.
|
||||
let method = req.method().clone();
|
||||
let url_path = req.url().to_string();
|
||||
let allow = method == Method::Post && url_path == "/v1/responses";
|
||||
|
||||
if !allow {
|
||||
let resp = Response::new_empty(StatusCode(403));
|
||||
let _ = req.respond(resp);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read request body
|
||||
let mut body = Vec::new();
|
||||
let mut reader = req.as_reader();
|
||||
std::io::Read::read_to_end(&mut reader, &mut body)?;
|
||||
|
||||
// Build headers for upstream, forwarding everything from the incoming
|
||||
// request except Authorization (we replace it below).
|
||||
let mut headers = HeaderMap::new();
|
||||
for header in req.headers() {
|
||||
let name_ascii = header.field.as_str();
|
||||
let lower = name_ascii.to_ascii_lowercase();
|
||||
if lower.as_str() == "authorization" || lower.as_str() == "host" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let header_name = match HeaderName::from_bytes(lower.as_bytes()) {
|
||||
Ok(name) => name,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if let Ok(value) = HeaderValue::from_bytes(header.value.as_bytes()) {
|
||||
headers.append(header_name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// As part of our effort to to keep `auth_header` secret, we use a
|
||||
// combination of `from_static()` and `set_sensitive(true)`.
|
||||
let mut auth_header_value = HeaderValue::from_static(auth_header);
|
||||
auth_header_value.set_sensitive(true);
|
||||
headers.insert(AUTHORIZATION, auth_header_value);
|
||||
|
||||
headers.insert(HOST, HeaderValue::from_static("api.openai.com"));
|
||||
|
||||
let upstream = "https://api.openai.com/v1/responses";
|
||||
let upstream_resp = client
|
||||
.post(upstream)
|
||||
.headers(headers)
|
||||
.body(body)
|
||||
.send()
|
||||
.context("forwarding request to upstream")?;
|
||||
|
||||
// We have to create an adapter between a `reqwest::blocking::Response`
|
||||
// and a `tiny_http::Response`. Fortunately, `reqwest::blocking::Response`
|
||||
// implements `Read`, so we can use it directly as the body of the
|
||||
// `tiny_http::Response`.
|
||||
let status = upstream_resp.status();
|
||||
let mut response_headers = Vec::new();
|
||||
for (name, value) in upstream_resp.headers().iter() {
|
||||
// Skip headers that tiny_http manages itself.
|
||||
if matches!(
|
||||
name.as_str(),
|
||||
"content-length" | "transfer-encoding" | "connection" | "trailer" | "upgrade"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(header) = Header::from_bytes(name.as_str().as_bytes(), value.as_bytes()) {
|
||||
response_headers.push(header);
|
||||
}
|
||||
}
|
||||
|
||||
let content_length = upstream_resp.content_length().and_then(|len| {
|
||||
if len <= usize::MAX as u64 {
|
||||
Some(len as usize)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let response = Response::new(
|
||||
StatusCode(status.as_u16()),
|
||||
response_headers,
|
||||
upstream_resp,
|
||||
content_length,
|
||||
None,
|
||||
);
|
||||
|
||||
let _ = req.respond(response);
|
||||
Ok(())
|
||||
}
|
||||
14
codex-rs/responses-api-proxy/src/main.rs
Normal file
14
codex-rs/responses-api-proxy/src/main.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_responses_api_proxy::Args as ResponsesApiProxyArgs;
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|_codex_linux_sandbox_exe| async move {
|
||||
let args = ResponsesApiProxyArgs::parse();
|
||||
tokio::task::spawn_blocking(move || codex_responses_api_proxy::run_main(args))
|
||||
.await
|
||||
.context("responses-api-proxy blocking task panicked")??;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
185
codex-rs/responses-api-proxy/src/read_api_key.rs
Normal file
185
codex-rs/responses-api-proxy/src/read_api_key.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use std::io::Read;
|
||||
use zeroize::Zeroize;
|
||||
|
||||
/// Use a generous buffer size to avoid truncation and to allow for longer API
|
||||
/// keys in the future.
|
||||
const BUFFER_SIZE: usize = 1024;
|
||||
const AUTH_HEADER_PREFIX: &[u8] = b"Bearer ";
|
||||
|
||||
/// Reads the auth token from stdin and returns a static `Authorization` header
|
||||
/// value with the auth token used with `Bearer`. The header value is returned
|
||||
/// as a `&'static str` whose bytes are locked in memory to avoid accidental
|
||||
/// exposure.
|
||||
pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> {
|
||||
read_auth_header_with(|buffer| std::io::stdin().read(buffer))
|
||||
}
|
||||
|
||||
fn read_auth_header_with<F>(read_fn: F) -> Result<&'static str>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> std::io::Result<usize>,
|
||||
{
|
||||
// TAKE CARE WHEN MODIFYING THIS CODE!!!
|
||||
//
|
||||
// This function goes to great lengths to avoid leaving the API key in
|
||||
// memory longer than necessary and to avoid copying it around. We read
|
||||
// directly into a stack buffer so the only heap allocation should be the
|
||||
// one to create the String (with the exact size) for the header value,
|
||||
// which we then immediately protect with mlock(2).
|
||||
let mut buf = [0u8; BUFFER_SIZE];
|
||||
buf[..AUTH_HEADER_PREFIX.len()].copy_from_slice(AUTH_HEADER_PREFIX);
|
||||
|
||||
let read = read_fn(&mut buf[AUTH_HEADER_PREFIX.len()..]).inspect_err(|_err| {
|
||||
buf.zeroize();
|
||||
})?;
|
||||
|
||||
if read == buf.len() - AUTH_HEADER_PREFIX.len() {
|
||||
buf.zeroize();
|
||||
return Err(anyhow!(
|
||||
"OPENAI_API_KEY is too large to fit in the 512-byte buffer"
|
||||
));
|
||||
}
|
||||
|
||||
let mut total = AUTH_HEADER_PREFIX.len() + read;
|
||||
while total > AUTH_HEADER_PREFIX.len() && (buf[total - 1] == b'\n' || buf[total - 1] == b'\r') {
|
||||
total -= 1;
|
||||
}
|
||||
|
||||
if total == AUTH_HEADER_PREFIX.len() {
|
||||
buf.zeroize();
|
||||
return Err(anyhow!(
|
||||
"OPENAI_API_KEY must be provided via stdin (e.g. printenv OPENAI_API_KEY | codex responses-api-proxy)"
|
||||
));
|
||||
}
|
||||
|
||||
let header_str = match std::str::from_utf8(&buf[..total]) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
buf.zeroize();
|
||||
return Err(err).context("reading Authorization header from stdin as UTF-8");
|
||||
}
|
||||
};
|
||||
|
||||
let header_value = String::from(header_str);
|
||||
buf.zeroize();
|
||||
|
||||
let leaked: &'static mut str = header_value.leak();
|
||||
mlock_str(leaked);
|
||||
|
||||
Ok(leaked)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn mlock_str(value: &str) {
|
||||
use libc::_SC_PAGESIZE;
|
||||
use libc::c_void;
|
||||
use libc::mlock;
|
||||
use libc::sysconf;
|
||||
|
||||
if value.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let page_size = unsafe { sysconf(_SC_PAGESIZE) };
|
||||
if page_size <= 0 {
|
||||
return;
|
||||
}
|
||||
let page_size = page_size as usize;
|
||||
if page_size == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let addr = value.as_ptr() as usize;
|
||||
let len = value.len();
|
||||
let start = addr & !(page_size - 1);
|
||||
let addr_end = match addr.checked_add(len) {
|
||||
Some(v) => match v.checked_add(page_size - 1) {
|
||||
Some(total) => total,
|
||||
None => return,
|
||||
},
|
||||
None => return,
|
||||
};
|
||||
let end = addr_end & !(page_size - 1);
|
||||
let size = end.saturating_sub(start);
|
||||
if size == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = unsafe { mlock(start as *const c_void, size) };
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn mlock_str(_value: &str) {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io;
|
||||
|
||||
#[test]
|
||||
fn reads_key_with_no_newlines() {
|
||||
let result = read_auth_header_with(|buf| {
|
||||
let data = b"sk-abc123";
|
||||
buf[..data.len()].copy_from_slice(data);
|
||||
Ok(data.len())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "Bearer sk-abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reads_key_and_trims_newlines() {
|
||||
let result = read_auth_header_with(|buf| {
|
||||
let data = b"sk-abc123\r\n";
|
||||
buf[..data.len()].copy_from_slice(data);
|
||||
Ok(data.len())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "Bearer sk-abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn errors_when_no_input_provided() {
|
||||
let err = read_auth_header_with(|_| Ok(0)).unwrap_err();
|
||||
let message = format!("{err:#}");
|
||||
assert!(message.contains("must be provided"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn errors_when_buffer_filled() {
|
||||
let err = read_auth_header_with(|buf| {
|
||||
let data = vec![b'a'; BUFFER_SIZE - AUTH_HEADER_PREFIX.len()];
|
||||
buf[..data.len()].copy_from_slice(&data);
|
||||
Ok(data.len())
|
||||
})
|
||||
.unwrap_err();
|
||||
let message = format!("{err:#}");
|
||||
assert!(message.contains("too large"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn propagates_io_error() {
|
||||
let err = read_auth_header_with(|_| Err(io::Error::other("boom"))).unwrap_err();
|
||||
|
||||
let io_error = err.downcast_ref::<io::Error>().unwrap();
|
||||
assert_eq!(io_error.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(io_error.to_string(), "boom");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn errors_on_invalid_utf8() {
|
||||
let err = read_auth_header_with(|buf| {
|
||||
let data = b"sk-abc\xff";
|
||||
buf[..data.len()].copy_from_slice(data);
|
||||
Ok(data.len())
|
||||
})
|
||||
.unwrap_err();
|
||||
|
||||
let message = format!("{err:#}");
|
||||
assert!(message.contains("UTF-8"));
|
||||
}
|
||||
}
|
||||
34
codex-rs/rmcp-client/Cargo.toml
Normal file
34
codex-rs/rmcp-client/Cargo.toml
Normal file
@@ -0,0 +1,34 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "codex-rmcp-client"
|
||||
version = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
rmcp = { version = "0.7.0", default-features = false, features = [
|
||||
"base64",
|
||||
"client",
|
||||
"macros",
|
||||
"schemars",
|
||||
"server",
|
||||
"transport-child-process",
|
||||
] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = [
|
||||
"io-util",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"sync",
|
||||
"io-std",
|
||||
"time",
|
||||
] }
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
142
codex-rs/rmcp-client/src/bin/rmcp_test_server.rs
Normal file
142
codex-rs/rmcp-client/src/bin/rmcp_test_server.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::ServiceExt;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::JsonObject;
|
||||
use rmcp::model::ListToolsResult;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::model::ServerCapabilities;
|
||||
use rmcp::model::ServerInfo;
|
||||
use rmcp::model::Tool;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestToolServer {
|
||||
tools: Arc<Vec<Tool>>,
|
||||
}
|
||||
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
|
||||
(tokio::io::stdin(), tokio::io::stdout())
|
||||
}
|
||||
impl TestToolServer {
|
||||
fn new() -> Self {
|
||||
let tools = vec![Self::echo_tool()];
|
||||
Self {
|
||||
tools: Arc::new(tools),
|
||||
}
|
||||
}
|
||||
|
||||
fn echo_tool() -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" },
|
||||
"env_var": { "type": "string" }
|
||||
},
|
||||
"required": ["message"],
|
||||
"additionalProperties": false
|
||||
}))
|
||||
.expect("echo tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("echo"),
|
||||
Cow::Borrowed("Echo back the provided message and include environment data."),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EchoArgs {
|
||||
message: String,
|
||||
#[allow(dead_code)]
|
||||
env_var: Option<String>,
|
||||
}
|
||||
|
||||
impl ServerHandler for TestToolServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_tools()
|
||||
.enable_tool_list_changed()
|
||||
.build(),
|
||||
..ServerInfo::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn list_tools(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> impl std::future::Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
|
||||
let tools = self.tools.clone();
|
||||
async move {
|
||||
Ok(ListToolsResult {
|
||||
tools: (*tools).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
request: CallToolRequestParam,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" => {
|
||||
let args: EchoArgs = match request.arguments {
|
||||
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
|
||||
arguments.into_iter().collect(),
|
||||
))
|
||||
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
|
||||
None => {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing arguments for echo tool",
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
|
||||
let structured_content = json!({
|
||||
"echo": args.message,
|
||||
"env": env_snapshot.get("MCP_TEST_VALUE"),
|
||||
});
|
||||
|
||||
Ok(CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(structured_content),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
other => Err(McpError::invalid_params(
|
||||
format!("unknown tool: {other}"),
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
eprintln!("starting rmcp test server");
|
||||
// Run the server with STDIO transport. If the client disconnects we simply
|
||||
// bubble up the error so the process exits.
|
||||
let service = TestToolServer::new();
|
||||
let running = service.serve(stdio()).await?;
|
||||
|
||||
// Wait for the client to finish interacting with the server.
|
||||
running.waiting().await?;
|
||||
// Drain background tasks to ensure clean shutdown.
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
||||
5
codex-rs/rmcp-client/src/lib.rs
Normal file
5
codex-rs/rmcp-client/src/lib.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod logging_client_handler;
|
||||
mod rmcp_client;
|
||||
mod utils;
|
||||
|
||||
pub use rmcp_client::RmcpClient;
|
||||
134
codex-rs/rmcp-client/src/logging_client_handler.rs
Normal file
134
codex-rs/rmcp-client/src/logging_client_handler.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use rmcp::ClientHandler;
|
||||
use rmcp::RoleClient;
|
||||
use rmcp::model::CancelledNotificationParam;
|
||||
use rmcp::model::ClientInfo;
|
||||
use rmcp::model::CreateElicitationRequestParam;
|
||||
use rmcp::model::CreateElicitationResult;
|
||||
use rmcp::model::ElicitationAction;
|
||||
use rmcp::model::LoggingLevel;
|
||||
use rmcp::model::LoggingMessageNotificationParam;
|
||||
use rmcp::model::ProgressNotificationParam;
|
||||
use rmcp::model::ResourceUpdatedNotificationParam;
|
||||
use rmcp::service::NotificationContext;
|
||||
use rmcp::service::RequestContext;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct LoggingClientHandler {
|
||||
client_info: ClientInfo,
|
||||
}
|
||||
|
||||
impl LoggingClientHandler {
|
||||
pub(crate) fn new(client_info: ClientInfo) -> Self {
|
||||
Self { client_info }
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientHandler for LoggingClientHandler {
|
||||
// TODO (CODEX-3571): support elicitations.
|
||||
async fn create_elicitation(
|
||||
&self,
|
||||
request: CreateElicitationRequestParam,
|
||||
_context: RequestContext<RoleClient>,
|
||||
) -> Result<CreateElicitationResult, rmcp::ErrorData> {
|
||||
info!(
|
||||
"MCP server requested elicitation ({}). Elicitations are not supported yet. Declining.",
|
||||
request.message
|
||||
);
|
||||
Ok(CreateElicitationResult {
|
||||
action: ElicitationAction::Decline,
|
||||
content: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn on_cancelled(
|
||||
&self,
|
||||
params: CancelledNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
info!(
|
||||
"MCP server cancelled request (request_id: {}, reason: {:?})",
|
||||
params.request_id, params.reason
|
||||
);
|
||||
}
|
||||
|
||||
async fn on_progress(
|
||||
&self,
|
||||
params: ProgressNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
info!(
|
||||
"MCP server progress notification (token: {:?}, progress: {}, total: {:?}, message: {:?})",
|
||||
params.progress_token, params.progress, params.total, params.message
|
||||
);
|
||||
}
|
||||
|
||||
async fn on_resource_updated(
|
||||
&self,
|
||||
params: ResourceUpdatedNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
info!("MCP server resource updated (uri: {})", params.uri);
|
||||
}
|
||||
|
||||
async fn on_resource_list_changed(&self, _context: NotificationContext<RoleClient>) {
|
||||
info!("MCP server resource list changed");
|
||||
}
|
||||
|
||||
async fn on_tool_list_changed(&self, _context: NotificationContext<RoleClient>) {
|
||||
info!("MCP server tool list changed");
|
||||
}
|
||||
|
||||
async fn on_prompt_list_changed(&self, _context: NotificationContext<RoleClient>) {
|
||||
info!("MCP server prompt list changed");
|
||||
}
|
||||
|
||||
fn get_info(&self) -> ClientInfo {
|
||||
self.client_info.clone()
|
||||
}
|
||||
|
||||
async fn on_logging_message(
|
||||
&self,
|
||||
params: LoggingMessageNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
let LoggingMessageNotificationParam {
|
||||
level,
|
||||
logger,
|
||||
data,
|
||||
} = params;
|
||||
let logger = logger.as_deref();
|
||||
match level {
|
||||
LoggingLevel::Emergency
|
||||
| LoggingLevel::Alert
|
||||
| LoggingLevel::Critical
|
||||
| LoggingLevel::Error => {
|
||||
error!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
LoggingLevel::Warning => {
|
||||
warn!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
LoggingLevel::Notice | LoggingLevel::Info => {
|
||||
info!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
LoggingLevel::Debug => {
|
||||
debug!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
183
codex-rs/rmcp-client/src/rmcp_client.rs
Normal file
183
codex-rs/rmcp-client/src/rmcp_client.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::io;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::InitializeRequestParams;
|
||||
use mcp_types::InitializeResult;
|
||||
use mcp_types::ListToolsRequestParams;
|
||||
use mcp_types::ListToolsResult;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::InitializeRequestParam;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::service::RoleClient;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::service::{self};
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
use crate::utils::convert_to_mcp;
|
||||
use crate::utils::convert_to_rmcp;
|
||||
use crate::utils::create_env_for_mcp_server;
|
||||
use crate::utils::run_with_timeout;
|
||||
|
||||
enum ClientState {
|
||||
Connecting {
|
||||
transport: Option<TokioChildProcess>,
|
||||
},
|
||||
Ready {
|
||||
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
||||
},
|
||||
}
|
||||
|
||||
/// MCP client implemented on top of the official `rmcp` SDK.
|
||||
/// https://github.com/modelcontextprotocol/rust-sdk
|
||||
pub struct RmcpClient {
|
||||
state: Mutex<ClientState>,
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
pub async fn new_stdio_client(
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
) -> io::Result<Self> {
|
||||
let program_name = program.to_string_lossy().into_owned();
|
||||
let mut command = Command::new(&program);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.env_clear()
|
||||
.envs(create_env_for_mcp_server(env))
|
||||
.args(&args);
|
||||
|
||||
let (transport, stderr) = TokioChildProcess::builder(command)
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match reader.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
info!("MCP server stderr ({program_name}): {line}");
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(error) => {
|
||||
warn!("Failed to read MCP server stderr ({program_name}): {error}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform the initialization handshake with the MCP server.
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
params: InitializeRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<InitializeResult> {
|
||||
let transport = {
|
||||
let mut guard = self.state.lock().await;
|
||||
match &mut *guard {
|
||||
ClientState::Connecting { transport } => transport
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("client already initializing"))?,
|
||||
ClientState::Ready { .. } => {
|
||||
return Err(anyhow!("client already initialized"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(client_info);
|
||||
let service_future = service::serve_client(client_handler, transport);
|
||||
|
||||
let service = match timeout {
|
||||
Some(duration) => time::timeout(duration, service_future)
|
||||
.await
|
||||
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
None => service_future
|
||||
.await
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
};
|
||||
|
||||
let initialize_result_rmcp = service
|
||||
.peer()
|
||||
.peer_info()
|
||||
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
|
||||
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
|
||||
|
||||
{
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
service: Arc::new(service),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(initialize_result)
|
||||
}
|
||||
|
||||
pub async fn list_tools(
|
||||
&self,
|
||||
params: Option<ListToolsRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListToolsResult> {
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
.transpose()?;
|
||||
|
||||
let fut = service.list_tools(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "tools/list").await?;
|
||||
convert_to_mcp(result)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<CallToolResult> {
|
||||
let service = self.service().await?;
|
||||
let params = CallToolRequestParams { arguments, name };
|
||||
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.call_tool(rmcp_params);
|
||||
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
|
||||
convert_call_tool_result(rmcp_result)
|
||||
}
|
||||
|
||||
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready { service } => Ok(Arc::clone(service)),
|
||||
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
|
||||
}
|
||||
}
|
||||
}
|
||||
160
codex-rs/rmcp-client/src/utils.rs
Normal file
160
codex-rs/rmcp-client/src/utils.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use mcp_types::CallToolResult;
|
||||
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
||||
use rmcp::service::ServiceError;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
|
||||
pub(crate) async fn run_with_timeout<F, T>(
|
||||
fut: F,
|
||||
timeout: Option<Duration>,
|
||||
label: &str,
|
||||
) -> Result<T>
|
||||
where
|
||||
F: std::future::Future<Output = Result<T, ServiceError>>,
|
||||
{
|
||||
if let Some(duration) = timeout {
|
||||
let result = time::timeout(duration, fut)
|
||||
.await
|
||||
.with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?;
|
||||
result.map_err(|err| anyhow!("{label} failed: {err}"))
|
||||
} else {
|
||||
fut.await.map_err(|err| anyhow!("{label} failed: {err}"))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result<CallToolResult> {
|
||||
let mut value = serde_json::to_value(result)?;
|
||||
if let Some(obj) = value.as_object_mut()
|
||||
&& (obj.get("content").is_none()
|
||||
|| obj.get("content").is_some_and(serde_json::Value::is_null))
|
||||
{
|
||||
obj.insert("content".to_string(), Value::Array(Vec::new()));
|
||||
}
|
||||
serde_json::from_value(value).context("failed to convert call tool result")
|
||||
}
|
||||
|
||||
/// Convert from mcp-types to Rust SDK types.
|
||||
///
|
||||
/// The Rust SDK types are the same as our mcp-types crate because they are both
|
||||
/// derived from the same MCP specification.
|
||||
/// As a result, it should be safe to convert directly from one to the other.
|
||||
pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
U: serde::de::DeserializeOwned,
|
||||
{
|
||||
let json = serde_json::to_value(value)?;
|
||||
serde_json::from_value(json).map_err(|err| anyhow!(err))
|
||||
}
|
||||
|
||||
/// Convert from Rust SDK types to mcp-types.
|
||||
///
|
||||
/// The Rust SDK types are the same as our mcp-types crate because they are both
|
||||
/// derived from the same MCP specification.
|
||||
/// As a result, it should be safe to convert directly from one to the other.
|
||||
pub(crate) fn convert_to_mcp<T, U>(value: T) -> Result<U>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
U: serde::de::DeserializeOwned,
|
||||
{
|
||||
let json = serde_json::to_value(value)?;
|
||||
serde_json::from_value(json).map_err(|err| anyhow!(err))
|
||||
}
|
||||
|
||||
pub(crate) fn create_env_for_mcp_server(
|
||||
extra_env: Option<HashMap<String, String>>,
|
||||
) -> HashMap<String, String> {
|
||||
DEFAULT_ENV_VARS
|
||||
.iter()
|
||||
.filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value)))
|
||||
.chain(extra_env.unwrap_or_default())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
"HOME",
|
||||
"LOGNAME",
|
||||
"PATH",
|
||||
"SHELL",
|
||||
"USER",
|
||||
"__CF_USER_TEXT_ENCODING",
|
||||
"LANG",
|
||||
"LC_ALL",
|
||||
"TERM",
|
||||
"TMPDIR",
|
||||
"TZ",
|
||||
];
|
||||
|
||||
#[cfg(windows)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
"PATH",
|
||||
"PATHEXT",
|
||||
"USERNAME",
|
||||
"USERDOMAIN",
|
||||
"USERPROFILE",
|
||||
"TEMP",
|
||||
"TMP",
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_types::ContentBlock;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_env_honors_overrides() {
|
||||
let value = "custom".to_string();
|
||||
let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])));
|
||||
assert_eq!(env.get("TZ"), Some(&value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_call_tool_result_defaults_missing_content() -> Result<()> {
|
||||
let structured_content = json!({ "key": "value" });
|
||||
let rmcp_result = RmcpCallToolResult {
|
||||
content: vec![],
|
||||
structured_content: Some(structured_content.clone()),
|
||||
is_error: Some(true),
|
||||
meta: None,
|
||||
};
|
||||
|
||||
let result = convert_call_tool_result(rmcp_result)?;
|
||||
|
||||
assert!(result.content.is_empty());
|
||||
assert_eq!(result.structured_content, Some(structured_content));
|
||||
assert_eq!(result.is_error, Some(true));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_call_tool_result_preserves_existing_content() -> Result<()> {
|
||||
let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]);
|
||||
|
||||
let result = convert_call_tool_result(rmcp_result)?;
|
||||
|
||||
assert_eq!(result.content.len(), 1);
|
||||
match &result.content[0] {
|
||||
ContentBlock::TextContent(text_content) => {
|
||||
assert_eq!(text_content.text, "hello");
|
||||
assert_eq!(text_content.r#type, "text");
|
||||
}
|
||||
other => panic!("expected text content got {other:?}"),
|
||||
}
|
||||
assert_eq!(result.structured_content, None);
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "1.89.0"
|
||||
components = [ "clippy", "rustfmt", "rust-src"]
|
||||
channel = "1.90.0"
|
||||
components = ["clippy", "rustfmt", "rust-src"]
|
||||
|
||||
@@ -40,24 +40,20 @@ codex-login = { workspace = true }
|
||||
codex-ollama = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
color-eyre = { workspace = true }
|
||||
crossterm = { workspace = true, features = [
|
||||
"bracketed-paste",
|
||||
"event-stream",
|
||||
] }
|
||||
dirs = { workspace = true }
|
||||
crossterm = { workspace = true, features = ["bracketed-paste", "event-stream"] }
|
||||
diffy = { workspace = true }
|
||||
image = { workspace = true, features = [
|
||||
"jpeg",
|
||||
"png",
|
||||
] }
|
||||
dirs = { workspace = true }
|
||||
image = { workspace = true, features = ["jpeg", "png"] }
|
||||
itertools = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
mcp-types = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
path-clean = { workspace = true }
|
||||
pathdiff = { workspace = true }
|
||||
pulldown-cmark = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
ratatui = { workspace = true, features = [
|
||||
"scrolling-regions",
|
||||
"unstable-backend-writer",
|
||||
"unstable-rendered-line-info",
|
||||
"unstable-widget-ref",
|
||||
] }
|
||||
@@ -81,11 +77,9 @@ tokio-stream = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-appender = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
pulldown-cmark = { workspace = true }
|
||||
unicode-segmentation = { workspace = true }
|
||||
unicode-width = { workspace = true }
|
||||
url = { workspace = true }
|
||||
pathdiff = { workspace = true }
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
libc = { workspace = true }
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user