Mirror of https://github.com/openai/codex.git (synced 2026-02-08 01:43:46 +00:00)

Compare commits: patch-guar ... jif/infty
62 Commits
| SHA1 |
|---|
| ce9347388a |
| d6515aa010 |
| 37b3807f96 |
| 7b8533fdbe |
| 5f47ab64c4 |
| 6915ba2100 |
| 50f53e7071 |
| 40fba1bb4c |
| bdda762deb |
| da5492694b |
| a5d48a775b |
| 78f2785595 |
| fc1723f131 |
| ed5b0bfeb3 |
| 4b01f0f50a |
| ac2b3ec2bb |
| c052b89333 |
| b424ca93ab |
| 32bd302d80 |
| 39c72b3151 |
| 2cdfd38c24 |
| fc79a46c7a |
| 010dfa7751 |
| 54b9436699 |
| af3bf801ce |
| 5fb6cbbcca |
| 7bdf63a009 |
| 119dabd272 |
| c0baaa171b |
| b45c204109 |
| 0139f6780c |
| 86ba270926 |
| c146585cdb |
| 5fa7844ad7 |
| 84c9b574f9 |
| 272e13dd90 |
| 18d00e36b9 |
| 17550fee9e |
| 995f5c3614 |
| 9b53a306e3 |
| 0016346dfb |
| f38ad65254 |
| 774892c6d7 |
| 897d4d5f17 |
| b70dcd80a2 |
| c0f8a49e3e |
| 87362d6ebd |
| f073bc5ccf |
| 9320565658 |
| 4de5b25c52 |
| 90b2f096c3 |
| f3c57ab888 |
| 43ee0dfd19 |
| c9d9a40c98 |
| ab3d607be4 |
| f7d8e12ae0 |
| a8278b5423 |
| cb99d71f57 |
| f72e9da7c5 |
| 732c435345 |
| f5e055ae36 |
| 8245a4f53b |
.github/workflows/rust-release.yml (vendored, 63 changes)

```diff
@@ -206,6 +206,69 @@ jobs:
             codesign --force --options runtime --timestamp --sign "$APPLE_CODESIGN_IDENTITY" "${keychain_args[@]}" "$path"
           done
 
+      - if: ${{ matrix.runner == 'macos-14' }}
+        name: Notarize macOS binaries
+        shell: bash
+        env:
+          APPLE_NOTARIZATION_KEY_P8: ${{ secrets.APPLE_NOTARIZATION_KEY_P8 }}
+          APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }}
+          APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }}
+        run: |
+          set -euo pipefail
+
+          for var in APPLE_NOTARIZATION_KEY_P8 APPLE_NOTARIZATION_KEY_ID APPLE_NOTARIZATION_ISSUER_ID; do
+            if [[ -z "${!var:-}" ]]; then
+              echo "$var is required for notarization"
+              exit 1
+            fi
+          done
+
+          notary_key_path="${RUNNER_TEMP}/notarytool.key.p8"
+          echo "$APPLE_NOTARIZATION_KEY_P8" | base64 -d > "$notary_key_path"
+          cleanup_notary() {
+            rm -f "$notary_key_path"
+          }
+          trap cleanup_notary EXIT
+
+          notarize_binary() {
+            local binary="$1"
+            local source_path="target/${{ matrix.target }}/release/${binary}"
+            local archive_path="${RUNNER_TEMP}/${binary}.zip"
+
+            if [[ ! -f "$source_path" ]]; then
+              echo "Binary $source_path not found"
+              exit 1
+            fi
+
+            rm -f "$archive_path"
+            ditto -c -k --keepParent "$source_path" "$archive_path"
+
+            submission_json=$(xcrun notarytool submit "$archive_path" \
+              --key "$notary_key_path" \
+              --key-id "$APPLE_NOTARIZATION_KEY_ID" \
+              --issuer "$APPLE_NOTARIZATION_ISSUER_ID" \
+              --output-format json \
+              --wait)
+
+            status=$(printf '%s\n' "$submission_json" | jq -r '.status // "Unknown"')
+            submission_id=$(printf '%s\n' "$submission_json" | jq -r '.id // ""')
+
+            if [[ -z "$submission_id" ]]; then
+              echo "Failed to retrieve submission ID for $binary"
+              exit 1
+            fi
+
+            echo "::notice title=Notarization::$binary submission ${submission_id} completed with status ${status}"
+
+            if [[ "$status" != "Accepted" ]]; then
+              echo "Notarization failed for ${binary} (submission ${submission_id}, status ${status})"
+              exit 1
+            fi
+          }
+
+          notarize_binary "codex"
+          notarize_binary "codex-responses-api-proxy"
+
       - name: Stage artifacts
         shell: bash
         run: |
```
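
The gate on `status != "Accepted"` is the important part: `notarytool --wait` returns JSON even for rejected submissions, so the job must fail explicitly. A minimal Rust sketch of the same id/status extraction the script does with `jq`, using `serde_json`; the JSON shape shown is assumed from the script above, not taken from Apple's documentation:

```rust
use serde_json::Value;

/// Extracts (submission id, status) from notarytool's JSON output,
/// mirroring `jq -r '.id // ""'` and `jq -r '.status // "Unknown"'`.
fn parse_notarytool_output(raw: &str) -> Result<(String, String), String> {
    let json: Value = serde_json::from_str(raw).map_err(|e| e.to_string())?;
    let id = json["id"].as_str().unwrap_or("").to_string();
    let status = json["status"].as_str().unwrap_or("Unknown").to_string();
    if id.is_empty() {
        return Err("failed to retrieve submission ID".to_string());
    }
    Ok((id, status))
}

fn main() {
    // Hypothetical output; the real fields come from `xcrun notarytool submit --wait`.
    let raw = r#"{"id":"abc-123","status":"Accepted"}"#;
    match parse_notarytool_output(raw) {
        Ok((id, status)) if status == "Accepted" => println!("submission {id} accepted"),
        Ok((id, status)) => eprintln!("submission {id} ended with status {status}"),
        Err(err) => eprintln!("{err}"),
    }
}
```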
.gitignore (vendored, 1 change)

```
@@ -30,6 +30,7 @@ result
# cli tools
CLAUDE.md
.claude/
AGENTS.override.md

# caches
.cache/
```
```
@@ -12,6 +12,7 @@ In the codex-rs folder where the rust code lives:
- Always inline format! args when possible per https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args
- Use method references over closures when possible per https://rust-lang.github.io/rust-clippy/master/index.html#redundant_closure_for_method_calls
- When writing tests, prefer comparing the equality of entire objects over fields one by one.
- When making a change that adds or changes an API, ensure that the documentation in the `docs/` folder is up to date if applicable.

Run `just fmt` (in `codex-rs` directory) automatically after making Rust code changes; do not ask for approval to run it. Before finalizing a change to `codex-rs`, run `just fix -p <project>` (in `codex-rs` directory) to fix any linter issues in the code. Prefer scoping with `-p` to avoid slow workspace-wide Clippy builds; only run `just fix` without `-p` if you changed shared crates. Additionally, run the tests:
```
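
A small test in the style those guidelines ask for, with a hypothetical `Point` struct (not from the codex codebase): format! args inlined, a method reference instead of a closure, and a whole-object equality assertion:

```rust
#[derive(Debug, PartialEq)]
struct Point {
    x: i32,
    y: i32,
}

#[test]
fn formats_and_compares_whole_objects() {
    let name = "origin";
    // Inlined format! args (uninlined_format_args lint).
    let label = format!("point {name}");
    assert_eq!(label, "point origin");

    // Method reference instead of `|s| s.trim()` (redundant_closure_for_method_calls lint).
    let trimmed: Vec<&str> = vec![" a ", " b "].into_iter().map(str::trim).collect();
    assert_eq!(trimmed, vec!["a", "b"]);

    // Compare the whole object, not field by field.
    let actual = Point { x: 1, y: 2 };
    assert_eq!(actual, Point { x: 1, y: 2 });
}
```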
codex-rs/Cargo.lock (generated, 224 changes)

```
@@ -941,6 +941,7 @@ dependencies = [
 "anyhow",
 "assert_cmd",
 "assert_matches",
 "chrono",
 "clap",
 "clap_complete",
 "codex-app-server",
@@ -951,6 +952,7 @@ dependencies = [
 "codex-common",
 "codex-core",
 "codex-exec",
 "codex-infty",
 "codex-login",
 "codex-mcp-server",
 "codex-process-hardening",
@@ -959,14 +961,20 @@ dependencies = [
 "codex-responses-api-proxy",
 "codex-rmcp-client",
 "codex-tui",
 "crossterm",
 "ctor 0.5.0",
 "owo-colors",
 "predicates",
 "pretty_assertions",
 "serde",
 "serde_json",
 "supports-color",
 "tempfile",
 "textwrap 0.16.2",
 "tokio",
 "tracing",
 "tracing-appender",
 "tracing-subscriber",
]

[[package]]
@@ -1076,6 +1084,7 @@ dependencies = [
 "thiserror 2.0.16",
 "time",
 "tokio",
 "tokio-stream",
 "tokio-test",
 "tokio-util",
 "toml",
@@ -1144,6 +1153,17 @@ dependencies = [
 "tempfile",
]

[[package]]
name = "codex-feedback"
version = "0.0.0"
dependencies = [
 "anyhow",
 "codex-protocol",
 "pretty_assertions",
 "sentry",
 "tracing-subscriber",
]

[[package]]
name = "codex-file-search"
version = "0.0.0"
@@ -1177,6 +1197,27 @@ dependencies = [
 "walkdir",
]

[[package]]
name = "codex-infty"
version = "0.0.0"
dependencies = [
 "anyhow",
 "chrono",
 "codex-core",
 "codex-protocol",
 "core_test_support",
 "dirs",
 "futures",
 "serde",
 "serde_json",
 "tempfile",
 "tokio",
 "tokio-stream",
 "tokio-util",
 "tracing",
 "wiremock",
]

[[package]]
name = "codex-linux-sandbox"
version = "0.0.0"
@@ -1354,6 +1395,7 @@ dependencies = [
 "axum",
 "codex-protocol",
 "dirs",
 "escargot",
 "futures",
 "keyring",
 "mcp-types",
@@ -1363,6 +1405,7 @@ dependencies = [
 "rmcp",
 "serde",
 "serde_json",
 "serial_test",
 "sha2",
 "tempfile",
 "tiny_http",
@@ -1388,6 +1431,7 @@ dependencies = [
 "codex-arg0",
 "codex-common",
 "codex-core",
 "codex-feedback",
 "codex-file-search",
 "codex-git-tooling",
 "codex-login",
@@ -1422,6 +1466,7 @@ dependencies = [
 "textwrap 0.16.2",
 "tokio",
 "tokio-stream",
 "toml",
 "tracing",
 "tracing-appender",
 "tracing-subscriber",
@@ -1823,6 +1868,16 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"

[[package]]
name = "debugid"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d"
dependencies = [
 "serde",
 "uuid",
]

[[package]]
name = "debugserver-types"
version = "0.5.0"
@@ -2301,6 +2356,18 @@ dependencies = [
 "winapi",
]

[[package]]
name = "findshlibs"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40b9e59cd0f7e0806cca4be089683ecb6434e602038df21fe6bf6711b2f07f64"
dependencies = [
 "cc",
 "lazy_static",
 "libc",
 "winapi",
]

[[package]]
name = "fixed_decimal"
version = "0.7.0"
@@ -2669,6 +2736,17 @@ dependencies = [
 "windows-sys 0.59.0",
]

[[package]]
name = "hostname"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56f203cd1c76362b69e3863fd987520ac36cf70a8c92627449b2f64a8cf7d65"
dependencies = [
 "cfg-if",
 "libc",
 "windows-link 0.1.3",
]

[[package]]
name = "http"
version = "1.3.1"
@@ -4905,6 +4983,15 @@ version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"

[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
 "semver",
]

[[package]]
name = "rustix"
version = "0.38.44"
@@ -5219,6 +5306,120 @@ dependencies = [
 "libc",
]

[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"

[[package]]
name = "sentry"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5484316556650182f03b43d4c746ce0e3e48074a21e2f51244b648b6542e1066"
dependencies = [
 "httpdate",
 "native-tls",
 "reqwest",
 "sentry-backtrace",
 "sentry-contexts",
 "sentry-core",
 "sentry-debug-images",
 "sentry-panic",
 "sentry-tracing",
 "tokio",
 "ureq",
]

[[package]]
name = "sentry-backtrace"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40aa225bb41e2ec9d7c90886834367f560efc1af028f1c5478a6cce6a59c463a"
dependencies = [
 "backtrace",
 "once_cell",
 "regex",
 "sentry-core",
]

[[package]]
name = "sentry-contexts"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a8dd746da3d16cb8c39751619cefd4fcdbd6df9610f3310fd646b55f6e39910"
dependencies = [
 "hostname",
 "libc",
 "os_info",
 "rustc_version",
 "sentry-core",
 "uname",
]

[[package]]
name = "sentry-core"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "161283cfe8e99c8f6f236a402b9ccf726b201f365988b5bb637ebca0abbd4a30"
dependencies = [
 "once_cell",
 "rand 0.8.5",
 "sentry-types",
 "serde",
 "serde_json",
]

[[package]]
name = "sentry-debug-images"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc6b25e945fcaa5e97c43faee0267eebda9f18d4b09a251775d8fef1086238a"
dependencies = [
 "findshlibs",
 "once_cell",
 "sentry-core",
]

[[package]]
name = "sentry-panic"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc74f229c7186dd971a9491ffcbe7883544aa064d1589bd30b83fb856cd22d63"
dependencies = [
 "sentry-backtrace",
 "sentry-core",
]

[[package]]
name = "sentry-tracing"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd3c5faf2103cd01eeda779ea439b68c4ee15adcdb16600836e97feafab362ec"
dependencies = [
 "sentry-backtrace",
 "sentry-core",
 "tracing-core",
 "tracing-subscriber",
]

[[package]]
name = "sentry-types"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d68cdf6bc41b8ff3ae2a9c4671e97426dcdd154cc1d4b6b72813f285d6b163f"
dependencies = [
 "debugid",
 "hex",
 "rand 0.8.5",
 "serde",
 "serde_json",
 "thiserror 1.0.69",
 "time",
 "url",
 "uuid",
]

[[package]]
name = "serde"
version = "1.0.226"
@@ -6073,6 +6274,7 @@ dependencies = [
 "futures-core",
 "pin-project-lite",
 "tokio",
 "tokio-util",
]

[[package]]
@@ -6426,6 +6628,15 @@ dependencies = [
 "winapi",
]

[[package]]
name = "uname"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b72f89f0ca32e4db1c04e2a72f5345d59796d4866a1ee0609084569f73683dc8"
dependencies = [
 "libc",
]

[[package]]
name = "unicase"
version = "2.8.1"
@@ -6485,6 +6696,19 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"

[[package]]
name = "ureq"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
dependencies = [
 "base64",
 "log",
 "native-tls",
 "once_cell",
 "url",
]

[[package]]
name = "url"
version = "2.5.4"
```
```
@@ -6,6 +6,7 @@ members = [
    "app-server-protocol",
    "apply-patch",
    "arg0",
    "codex-infty",
    "codex-backend-openapi-models",
    "cloud-tasks",
    "cloud-tasks-client",
@@ -14,6 +15,7 @@ members = [
    "core",
    "exec",
    "execpolicy",
    "feedback",
    "file-search",
    "git-tooling",
    "linux-sandbox",
@@ -56,6 +58,7 @@ codex-chatgpt = { path = "chatgpt" }
codex-common = { path = "common" }
codex-core = { path = "core" }
codex-exec = { path = "exec" }
codex-feedback = { path = "feedback" }
codex-file-search = { path = "file-search" }
codex-git-tooling = { path = "git-tooling" }
codex-linux-sandbox = { path = "linux-sandbox" }
@@ -83,8 +86,8 @@ ansi-to-tui = "7.0.0"
anyhow = "1"
arboard = "3"
askama = "0.12"
assert_matches = "1.5.0"
assert_cmd = "2"
assert_matches = "1.5.0"
async-channel = "2.3.1"
async-stream = "0.3.6"
async-trait = "0.1.89"
@@ -147,6 +150,7 @@ reqwest = "0.12"
rmcp = { version = "0.8.0", default-features = false }
schemars = "0.8.22"
seccompiler = "0.5.0"
sentry = "0.34.0"
serde = "1"
serde_json = "1"
serde_with = "3.14"
```
```
@@ -9,6 +9,7 @@ use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::config_types::Verbosity;
use codex_protocol::parse_command::ParsedCommand;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::FileChange;
@@ -697,6 +698,7 @@ pub struct ExecCommandApprovalParams {
    pub cwd: PathBuf,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    pub parsed_cmd: Vec<ParsedCommand>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
@@ -904,6 +906,9 @@ mod tests {
            command: vec!["echo".to_string(), "hello".to_string()],
            cwd: PathBuf::from("/tmp"),
            reason: Some("because tests".to_string()),
            parsed_cmd: vec![ParsedCommand::Unknown {
                cmd: "echo hello".to_string(),
            }],
        };
        let request = ServerRequest::ExecCommandApproval {
            request_id: RequestId::Integer(7),
@@ -920,6 +925,12 @@ mod tests {
                "command": ["echo", "hello"],
                "cwd": "/tmp",
                "reason": "because tests",
                "parsedCmd": [
                    {
                        "type": "unknown",
                        "cmd": "echo hello"
                    }
                ]
            }
        }),
        serde_json::to_value(&request)?,
```
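
Note that `reason` carries `#[serde(skip_serializing_if = "Option::is_none")]`, so an absent reason is omitted from the JSON entirely rather than serialized as `null`. A minimal sketch of that serde behavior with a stand-in struct (not the real `ExecCommandApprovalParams`):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Params {
    cwd: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    reason: Option<String>,
}

fn main() {
    let with = Params { cwd: "/tmp".into(), reason: Some("because tests".into()) };
    let without = Params { cwd: "/tmp".into(), reason: None };
    // {"cwd":"/tmp","reason":"because tests"}
    println!("{}", serde_json::to_string(&with).unwrap());
    // {"cwd":"/tmp"} -- the key is omitted, not emitted as null
    println!("{}", serde_json::to_string(&without).unwrap());
}
```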
```
@@ -1284,6 +1284,7 @@ async fn apply_bespoke_event_handling(
            command,
            cwd,
            reason,
            parsed_cmd,
        }) => {
            let params = ExecCommandApprovalParams {
                conversation_id,
@@ -1291,6 +1292,7 @@ async fn apply_bespoke_event_handling(
                command,
                cwd,
                reason,
                parsed_cmd,
            };
            let rx = outgoing
                .send_request(ServerRequestPayload::ExecCommandApproval(params))
```
```
@@ -27,6 +27,7 @@ use codex_core::protocol_config_types::ReasoningEffort;
use codex_core::protocol_config_types::ReasoningSummary;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::parse_command::ParsedCommand;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::InputMessageKind;
@@ -311,6 +312,9 @@ async fn test_send_user_turn_changes_approval_policy_behavior() {
        ],
        cwd: working_directory.clone(),
        reason: None,
        parsed_cmd: vec![ParsedCommand::Unknown {
            cmd: "python3 -c 'print(42)'".to_string()
        }],
    },
    params
);
```
```
@@ -35,6 +35,7 @@ codex-tui = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-cloud-tasks = { path = "../cloud-tasks" }
ctor = { workspace = true }
crossterm = { workspace = true }
owo-colors = { workspace = true }
serde_json = { workspace = true }
supports-color = { workspace = true }
@@ -45,6 +46,13 @@ tokio = { workspace = true, features = [
    "rt-multi-thread",
    "signal",
] }
codex-infty = { path = "../codex-infty" }
chrono = { workspace = true }
serde = { workspace = true, features = ["derive"] }
tracing = "0.1.41"
tracing-appender = "0.2.3"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
textwrap = { workspace = true }

[dev-dependencies]
assert_matches = { workspace = true }
```
codex-rs/cli/src/infty/args.rs (new file, 115 lines)

```rust
use std::path::PathBuf;

use anyhow::Result;
use clap::Parser;
use clap::Subcommand;
use codex_common::CliConfigOverrides;
use codex_protocol::config_types::ReasoningEffort;

use super::commands;

#[derive(Debug, Parser)]
pub struct InftyCli {
    #[clap(flatten)]
    pub config_overrides: CliConfigOverrides,

    /// Override the default runs root (`~/.codex/infty`).
    #[arg(long = "runs-root", value_name = "DIR")]
    pub runs_root: Option<PathBuf>,

    #[command(subcommand)]
    command: InftyCommand,
}

#[derive(Debug, Subcommand)]
enum InftyCommand {
    /// Create a new run store and spawn solver/director sessions.
    Create(CreateArgs),

    /// List stored runs.
    List(ListArgs),

    /// Show metadata for a stored run.
    Show(ShowArgs),
    // resumable runs are disabled; Drive command removed
}

#[derive(Debug, Parser)]
pub(crate) struct CreateArgs {
    /// Explicit run id. If omitted, a timestamp-based id is generated.
    #[arg(long = "run-id", value_name = "RUN_ID")]
    pub run_id: Option<String>,

    /// Optional objective to send to the solver immediately after creation.
    #[arg(long)]
    pub objective: Option<String>,

    /// Timeout in seconds when waiting for the solver reply to --objective.
    #[arg(long = "timeout-secs", default_value_t = super::commands::DEFAULT_TIMEOUT_SECS)]
    pub timeout_secs: u64,

    /// Override only the Director's model (solver and verifiers keep defaults).
    #[arg(long = "director-model", value_name = "MODEL")]
    pub director_model: Option<String>,

    /// Override only the Director's reasoning effort (minimal|low|medium|high).
    #[arg(
        long = "director-effort",
        value_name = "LEVEL",
        value_parser = parse_reasoning_effort
    )]
    pub director_effort: Option<ReasoningEffort>,
}

#[derive(Debug, Parser)]
pub(crate) struct ListArgs {
    /// Emit JSON describing the stored runs.
    #[arg(long)]
    pub json: bool,
}

#[derive(Debug, Parser)]
pub(crate) struct ShowArgs {
    /// Run id to display.
    #[arg(value_name = "RUN_ID")]
    pub run_id: String,

    /// Emit JSON metadata instead of human-readable text.
    #[arg(long)]
    pub json: bool,
}

// resumable runs are disabled; DriveArgs removed

impl InftyCli {
    pub async fn run(self) -> Result<()> {
        let InftyCli {
            config_overrides,
            runs_root,
            command,
        } = self;

        match command {
            InftyCommand::Create(args) => {
                commands::run_create(config_overrides, runs_root, args).await?;
            }
            InftyCommand::List(args) => commands::run_list(runs_root, args)?,
            InftyCommand::Show(args) => commands::run_show(runs_root, args)?,
            // Drive removed
        }

        Ok(())
    }
}

fn parse_reasoning_effort(s: &str) -> Result<ReasoningEffort, String> {
    match s.trim().to_ascii_lowercase().as_str() {
        "minimal" => Ok(ReasoningEffort::Minimal),
        "low" => Ok(ReasoningEffort::Low),
        "medium" => Ok(ReasoningEffort::Medium),
        "high" => Ok(ReasoningEffort::High),
        _ => Err(format!(
            "invalid reasoning effort: {s}. Expected one of: minimal|low|medium|high"
        )),
    }
}
```
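
For illustration, how clap wires a custom `value_parser` like `parse_reasoning_effort` above: clap calls the function on the raw argument string and turns an `Err` into a usage error. The miniature below is self-contained, with a hypothetical `Effort` enum standing in for `ReasoningEffort`; the flag name matches the real CLI, the types do not:

```rust
use clap::Parser;

#[derive(Debug, Clone, Copy)]
enum Effort { Minimal, Low, Medium, High }

fn parse_effort(s: &str) -> Result<Effort, String> {
    match s.trim().to_ascii_lowercase().as_str() {
        "minimal" => Ok(Effort::Minimal),
        "low" => Ok(Effort::Low),
        "medium" => Ok(Effort::Medium),
        "high" => Ok(Effort::High),
        _ => Err(format!("invalid reasoning effort: {s}")),
    }
}

#[derive(Debug, Parser)]
struct Mini {
    /// clap invokes parse_effort and reports Err values as usage errors.
    #[arg(long = "director-effort", value_parser = parse_effort)]
    director_effort: Option<Effort>,
}

fn main() {
    // Parses case-insensitively, e.g. `mini --director-effort HIGH` also works.
    let args = Mini::parse_from(["mini", "--director-effort", "high"]);
    println!("{args:?}");
}
```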
codex-rs/cli/src/infty/commands.rs (new file, 438 lines)

```rust
use std::fs;
use std::fs::OpenOptions;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use chrono::SecondsFormat;
use chrono::Utc;
use codex_common::CliConfigOverrides;
use codex_core::CodexAuth;
use codex_core::auth::read_codex_api_key_from_env;
use codex_core::auth::read_openai_api_key_from_env;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_infty::InftyOrchestrator;
use codex_infty::RoleConfig;
use codex_infty::RunExecutionOptions;
use codex_infty::RunParams;
use codex_infty::RunStore;
use owo_colors::OwoColorize;
use serde::Serialize;
use std::sync::OnceLock;
use supports_color::Stream;
use tracing_appender::non_blocking;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::prelude::*;

use super::args::CreateArgs;
use super::args::ListArgs;
use super::args::ShowArgs;
use super::progress::TerminalProgressReporter;
use super::summary::print_run_summary_box;

const DEFAULT_VERIFIER_ROLES: [&str; 3] = ["verifier-alpha", "verifier-beta", "verifier-gamma"];

pub(crate) const DEFAULT_TIMEOUT_SECS: u64 = 6000;

#[derive(Debug, Serialize)]
struct RunSummary {
    run_id: String,
    path: String,
    created_at: String,
    updated_at: String,
    roles: Vec<String>,
}

pub(crate) async fn run_create(
    config_overrides: CliConfigOverrides,
    runs_root_override: Option<PathBuf>,
    args: CreateArgs,
) -> Result<()> {
    let config = load_config(config_overrides).await?;
    init_infty_logging(&config)?;
    let auth = load_auth(&config)?;
    let runs_root = resolve_runs_root(runs_root_override)?;
    let color_enabled = supports_color::on(Stream::Stdout).is_some();

    let mut run_id = if let Some(id) = args.run_id {
        id
    } else {
        generate_run_id()
    };
    run_id = run_id.trim().to_string();
    validate_run_id(&run_id)?;

    let run_path = runs_root.join(&run_id);
    if run_path.exists() {
        bail!("run {run_id} already exists at {}", run_path.display());
    }

    let orchestrator = InftyOrchestrator::with_runs_root(auth, runs_root).with_progress(Arc::new(
        TerminalProgressReporter::with_color(color_enabled),
    ));
    let verifiers: Vec<RoleConfig> = DEFAULT_VERIFIER_ROLES
        .iter()
        .map(|role| RoleConfig::new(role.to_string(), config.clone()))
        .collect();
    let mut director_config = config.clone();
    if let Some(model) = args.director_model.as_deref() {
        director_config.model = model.to_string();
    }
    if let Some(effort) = args.director_effort {
        director_config.model_reasoning_effort = Some(effort);
    }
    let run_params = RunParams {
        run_id: run_id.clone(),
        run_root: Some(run_path.clone()),
        solver: RoleConfig::new("solver", config.clone()),
        director: RoleConfig::new("director", director_config),
        verifiers,
    };

    if let Some(objective) = args.objective {
        let timeout = Duration::from_secs(args.timeout_secs);
        let options = RunExecutionOptions {
            objective: Some(objective),
            director_timeout: timeout,
            verifier_timeout: timeout,
        };

        let start = Instant::now();
        let start_header = format!("Starting run {run_id}");
        if color_enabled {
            println!("{}", start_header.blue().bold());
        } else {
            println!("{start_header}");
        }
        let location_line = format!("  run directory: {}", run_path.display());
        if color_enabled {
            println!("{}", location_line.dimmed());
        } else {
            println!("{location_line}");
        }
        if let Some(objective_text) = options.objective.as_deref()
            && !objective_text.trim().is_empty()
        {
            let objective_line = format!("  objective: {objective_text}");
            if color_enabled {
                println!("{}", objective_line.dimmed());
            } else {
                println!("{objective_line}");
            }
        }
        println!();

        let objective_snapshot = options.objective.clone();
        let outcome = orchestrator
            .execute_new_run(run_params, options)
            .await
            .with_context(|| format!("failed to execute run {run_id}"))?;
        let duration = start.elapsed();
        print_run_summary_box(
            color_enabled,
            &run_id,
            &run_path,
            &outcome.deliverable_path,
            outcome.summary.as_deref(),
            objective_snapshot.as_deref(),
            duration,
        );
    } else {
        let sessions = orchestrator
            .spawn_run(run_params)
            .await
            .with_context(|| format!("failed to create run {run_id}"))?;

        println!(
            "Created run {run_id} at {}",
            sessions.store.path().display()
        );
    }

    Ok(())
}

pub(crate) fn run_list(runs_root_override: Option<PathBuf>, args: ListArgs) -> Result<()> {
    // Initialize logging using default Codex home discovery.
    let _ = init_infty_logging_from_home();
    let runs_root = resolve_runs_root(runs_root_override)?;
    let listings = collect_run_summaries(&runs_root)?;

    if args.json {
        println!("{}", serde_json::to_string_pretty(&listings)?);
        return Ok(());
    }

    if listings.is_empty() {
        println!("No runs found under {}", runs_root.display());
        return Ok(());
    }

    println!("Runs in {}", runs_root.display());
    for summary in listings {
        println!(
            "{}\t{}\t{}",
            summary.run_id, summary.updated_at, summary.path
        );
    }

    Ok(())
}

pub(crate) fn run_show(runs_root_override: Option<PathBuf>, args: ShowArgs) -> Result<()> {
    validate_run_id(&args.run_id)?;
    let _ = init_infty_logging_from_home();
    let runs_root = resolve_runs_root(runs_root_override)?;
    let run_path = runs_root.join(&args.run_id);
    let store =
        RunStore::load(&run_path).with_context(|| format!("failed to load run {}", args.run_id))?;
    let metadata = store.metadata();

    let summary = RunSummary {
        run_id: metadata.run_id.clone(),
        path: run_path.display().to_string(),
        created_at: metadata
            .created_at
            .to_rfc3339_opts(SecondsFormat::Secs, true),
        updated_at: metadata
            .updated_at
            .to_rfc3339_opts(SecondsFormat::Secs, true),
        roles: metadata
            .roles
            .iter()
            .map(|role| role.role.clone())
            .collect(),
    };

    if args.json {
        println!("{}", serde_json::to_string_pretty(&summary)?);
        return Ok(());
    }

    println!("Run: {}", summary.run_id);
    println!("Path: {}", summary.path);
    println!("Created: {}", summary.created_at);
    println!("Updated: {}", summary.updated_at);
    println!("Roles: {}", summary.roles.join(", "));

    Ok(())
}

// resumable runs are disabled; run_drive removed

fn generate_run_id() -> String {
    let timestamp = Utc::now().format("run-%Y%m%d-%H%M%S");
    format!("{timestamp}")
}

pub(crate) fn validate_run_id(run_id: &str) -> Result<()> {
    if run_id.is_empty() {
        bail!("run id must not be empty");
    }
    if run_id.starts_with('.') || run_id.ends_with('.') {
        bail!("run id must not begin or end with '.'");
    }
    if run_id
        .chars()
        .any(|c| !(c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.')))
    {
        bail!("run id may only contain ASCII alphanumerics, '-', '_', or '.'");
    }
    Ok(())
}

async fn load_config(cli_overrides: CliConfigOverrides) -> Result<Config> {
    let overrides = cli_overrides
        .parse_overrides()
        .map_err(|err| anyhow!("failed to parse -c overrides: {err}"))?;
    Config::load_with_cli_overrides(overrides, ConfigOverrides::default())
        .await
        .context("failed to load Codex configuration")
}

fn load_auth(config: &Config) -> Result<CodexAuth> {
    if let Some(auth) =
        CodexAuth::from_codex_home(&config.codex_home).context("failed to read auth.json")?
    {
        return Ok(auth);
    }
    if let Some(api_key) = read_codex_api_key_from_env() {
        return Ok(CodexAuth::from_api_key(&api_key));
    }
    if let Some(api_key) = read_openai_api_key_from_env() {
        return Ok(CodexAuth::from_api_key(&api_key));
    }
    bail!("no Codex authentication found. Run `codex login` or set OPENAI_API_KEY.");
}

fn resolve_runs_root(override_path: Option<PathBuf>) -> Result<PathBuf> {
    if let Some(path) = override_path {
        return Ok(path);
    }
    codex_infty::default_runs_root()
}

fn collect_run_summaries(root: &Path) -> Result<Vec<RunSummary>> {
    let mut summaries = Vec::new();
    let iter = match fs::read_dir(root) {
        Ok(read_dir) => read_dir,
        Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(summaries),
        Err(err) => {
            return Err(
                anyhow!(err).context(format!("failed to read runs root {}", root.display()))
            );
        }
    };

    for entry in iter {
        let entry = entry?;
        if !entry.file_type()?.is_dir() {
            continue;
        }
        let run_path = entry.path();
        let store = match RunStore::load(&run_path) {
            Ok(store) => store,
            Err(err) => {
                eprintln!(
                    "Skipping {}: failed to load run metadata: {err}",
                    run_path.display()
                );
                continue;
            }
        };
        let metadata = store.metadata();
        summaries.push(RunSummary {
            run_id: metadata.run_id.clone(),
            path: run_path.display().to_string(),
            created_at: metadata
                .created_at
                .to_rfc3339_opts(SecondsFormat::Secs, true),
            updated_at: metadata
                .updated_at
                .to_rfc3339_opts(SecondsFormat::Secs, true),
            roles: metadata
                .roles
                .iter()
                .map(|role| role.role.clone())
                .collect(),
        });
    }

    summaries.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
    Ok(summaries)
}

fn init_infty_logging(config: &codex_core::config::Config) -> std::io::Result<()> {
    let log_dir = codex_core::config::log_dir(config)?;
    std::fs::create_dir_all(&log_dir)?;

    let mut log_file_opts = OpenOptions::new();
    log_file_opts.create(true).append(true);

    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        log_file_opts.mode(0o600);
    }

    let log_file = log_file_opts.open(log_dir.join("codex-infty.log"))?;
    let (non_blocking, guard) = non_blocking(log_file);
    static INFTY_LOG_GUARD: OnceLock<tracing_appender::non_blocking::WorkerGuard> = OnceLock::new();
    let _ = INFTY_LOG_GUARD.set(guard);

    // Use RUST_LOG if set, otherwise default to info for common codex crates
    let env_filter = || {
        EnvFilter::try_from_default_env()
            .unwrap_or_else(|_| EnvFilter::new("codex_core=info,codex_infty=info,codex_cli=info"))
    };

    let file_layer = tracing_subscriber::fmt::layer()
        .with_writer(non_blocking)
        .with_target(false)
        .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
        .with_filter(env_filter());

    // Initialize once; subsequent calls are no-ops.
    let _ = tracing_subscriber::registry().with(file_layer).try_init();
    Ok(())
}

fn init_infty_logging_from_home() -> std::io::Result<()> {
    let mut log_dir = codex_core::config::find_codex_home()?;
    log_dir.push("log");
    std::fs::create_dir_all(&log_dir)?;

    let mut log_file_opts = OpenOptions::new();
    log_file_opts.create(true).append(true);

    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        log_file_opts.mode(0o600);
    }

    let log_file = log_file_opts.open(log_dir.join("codex-infty.log"))?;
    let (non_blocking, guard) = non_blocking(log_file);
    static INFTY_LOG_GUARD: OnceLock<tracing_appender::non_blocking::WorkerGuard> = OnceLock::new();
    let _ = INFTY_LOG_GUARD.set(guard);

    let env_filter = || {
        EnvFilter::try_from_default_env()
            .unwrap_or_else(|_| EnvFilter::new("codex_core=info,codex_infty=info,codex_cli=info"))
    };

    let file_layer = tracing_subscriber::fmt::layer()
        .with_writer(non_blocking)
        .with_target(false)
        .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
        .with_filter(env_filter());

    let _ = tracing_subscriber::registry().with(file_layer).try_init();
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn default_verifier_roles_are_stable() {
        assert_eq!(
            DEFAULT_VERIFIER_ROLES,
            ["verifier-alpha", "verifier-beta", "verifier-gamma"]
        );
    }

    #[test]
    fn validates_run_ids() {
        assert!(validate_run_id("run-20241030-123000").is_ok());
        assert!(validate_run_id("run.alpha").is_ok());
        assert!(validate_run_id("").is_err());
        assert!(validate_run_id("..bad").is_err());
        assert!(validate_run_id("bad/value").is_err());
    }

    #[test]
    fn generates_timestamped_run_id() {
        let run_id = generate_run_id();
        assert!(run_id.starts_with("run-"));
        assert_eq!(run_id.len(), "run-YYYYMMDD-HHMMSS".len());
    }

    #[test]
    fn collect_summaries_returns_empty_for_missing_root() {
        let temp = TempDir::new().expect("temp dir");
        let missing = temp.path().join("not-present");
        let summaries = collect_run_summaries(&missing).expect("collect");
        assert!(summaries.is_empty());
    }
}
```
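
The two logging initializers differ only in how they locate the log directory; both resolve their filter the same way: an explicit `RUST_LOG` wins, otherwise the hard-coded per-crate default applies. A standalone sketch of that pattern (the directive string is the one from the file; the surrounding program is illustrative):

```rust
use tracing_subscriber::EnvFilter;

fn main() {
    // RUST_LOG=debug (or any full directive string) overrides the default;
    // with no RUST_LOG set, the per-crate info directives apply.
    let filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new("codex_core=info,codex_infty=info,codex_cli=info"));
    println!("active filter: {filter}");
}
```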
codex-rs/cli/src/infty/mod.rs (new file, 6 lines)

```rust
mod args;
mod commands;
mod progress;
mod summary;

pub use args::InftyCli;
```
codex-rs/cli/src/infty/progress.rs (new file, 194 lines)

```rust
use chrono::Local;
use codex_core::protocol::AgentMessageEvent;
use codex_core::protocol::EventMsg;
use codex_infty::AggregatedVerifierVerdict;
use codex_infty::DirectiveResponse;
use codex_infty::ProgressReporter;
use codex_infty::VerifierDecision;
use codex_infty::VerifierVerdict;
use crossterm::style::Stylize;
use std::path::Path;
use supports_color::Stream;

#[derive(Debug, Default, Clone)]
pub(crate) struct TerminalProgressReporter;

impl TerminalProgressReporter {
    pub(crate) fn with_color(_color_enabled: bool) -> Self {
        Self
    }

    fn format_role_label(&self, role: &str) -> String {
        let lower = role.to_ascii_lowercase();
        if lower == "solver" {
            return "[solver]".magenta().bold().to_string();
        }
        if lower == "director" {
            return "[director]".blue().bold().to_string();
        }
        if lower == "user" {
            return "[user]".cyan().bold().to_string();
        }
        if lower.contains("verifier") {
            return format!("[{role}]").green().bold().to_string();
        }
        format!("[{role}]").magenta().bold().to_string()
    }

    fn timestamp(&self) -> String {
        let timestamp = Local::now().format("%H:%M:%S");
        let display = format!("[{timestamp}]");
        if supports_color::on(Stream::Stdout).is_some() {
            format!("{}", display.dim())
        } else {
            display
        }
    }

    fn print_exchange(
        &self,
        from_role: &str,
        to_role: &str,
        lines: Vec<String>,
        trailing_blank_line: bool,
    ) {
        let header = format!(
            "{} ----> {}",
            self.format_role_label(from_role),
            self.format_role_label(to_role)
        );
        println!("{} {header}", self.timestamp());
        for line in lines {
            println!("{line}");
        }
        if trailing_blank_line {
            println!();
        }
    }

    fn format_decision(&self, decision: VerifierDecision) -> String {
        match decision {
            VerifierDecision::Pass => "pass".green().bold().to_string(),
            VerifierDecision::Fail => "fail".red().bold().to_string(),
        }
    }
}

impl ProgressReporter for TerminalProgressReporter {
    fn objective_posted(&self, objective: &str) {
        let objective_line = format!("{}", format!("→ objective: {objective}").dim());
        self.print_exchange("user", "solver", vec![objective_line], true);
    }

    fn solver_event(&self, event: &EventMsg) {
        match serde_json::to_string_pretty(event) {
            Ok(json) => {
                tracing::debug!("[solver:event]\n{json}");
            }
            Err(err) => {
                tracing::warn!("[solver:event] (failed to serialize: {err}) {event:?}");
            }
        }
    }

    fn role_event(&self, role: &str, event: &EventMsg) {
        match serde_json::to_string_pretty(event) {
            Ok(json) => {
                tracing::debug!("[{role}:event]\n{json}");
            }
            Err(err) => {
                tracing::warn!("[{role}:event] (failed to serialize: {err}) {event:?}");
            }
        }
    }

    fn solver_agent_message(&self, agent_msg: &AgentMessageEvent) {
        tracing::info!("Agent Message: {agent_msg:?}");
    }

    fn invalid_solver_signal(&self, raw_message: &str) {
        let heading = "Warning".yellow().bold();
        let body = format!(
            "solver reply did not match expected JSON signal; got: {}",
            raw_message
        );
        println!("{} {} {}", self.timestamp(), heading, body);
    }

    fn direction_request(&self, prompt: &str) {
        let prompt_line = format!("{}", prompt.yellow());
        self.print_exchange("solver", "director", vec![prompt_line], true);
    }

    fn director_response(&self, directive: &DirectiveResponse) {
        let suffix = directive
            .rationale
            .as_deref()
            .filter(|rationale| !rationale.is_empty())
            .map(|rationale| format!(" (rationale: {rationale})"))
            .unwrap_or_default();
        let directive_line = format!("{}{}", directive.directive, suffix);
        self.print_exchange("director", "solver", vec![directive_line], true);
    }

    fn verification_request(&self, claim_path: &str, notes: Option<&str>) {
        let mut lines = Vec::new();
        let path_line = format!("→ path: {claim_path}");
        lines.push(format!("{}", path_line.dim()));
        if let Some(notes) = notes.filter(|notes| !notes.is_empty()) {
            let note_line = format!("→ note: {notes}");
            lines.push(format!("{}", note_line.dim()));
        }
        self.print_exchange("solver", "verifier", lines, true);
    }

    fn verifier_verdict(&self, role: &str, verdict: &VerifierVerdict) {
        let decision = self.format_decision(verdict.verdict);
        let mut lines = Vec::new();
        lines.push(format!("verdict: {decision}"));
        if !verdict.reasons.is_empty() {
            let reasons = verdict.reasons.join("; ");
            let reason_line = format!("→ reasons: {reasons}");
            lines.push(format!("{}", reason_line.dim()));
        }
        if !verdict.suggestions.is_empty() {
            let suggestions = verdict.suggestions.join("; ");
            let suggestion_line = format!("→ suggestions: {suggestions}");
            lines.push(format!("{}", suggestion_line.dim()));
        }
        self.print_exchange(role, "solver", lines, false);
    }

    fn verification_summary(&self, summary: &AggregatedVerifierVerdict) {
        let decision = self.format_decision(summary.overall);
        let heading = "Verification summary".bold();
        let summary_line = format!("{heading}: {decision}");
        self.print_exchange("verifier", "solver", vec![summary_line], true);
    }

    fn final_delivery(&self, deliverable_path: &Path, summary: Option<&str>) {
        let delivery_line = format!(
            "{}",
            format!("→ path: {}", deliverable_path.display()).dim()
        );
        let summary_line = format!(
            "{}",
            format!("→ summary: {}", summary.unwrap_or("<none>")).dim()
        );
        self.print_exchange(
            "solver",
            "verifier",
            vec![delivery_line, summary_line],
            true,
        );
    }

    fn run_interrupted(&self) {
        println!(
            "{}",
            "Run interrupted by Ctrl+C. Shutting down sessions…"
                .red()
                .bold(),
        );
    }
}
```
codex-rs/cli/src/infty/summary.rs (new file, 123 lines)

```rust
use std::path::Path;
use std::time::Duration;

use codex_common::elapsed::format_duration;
use crossterm::terminal;
use owo_colors::OwoColorize;
use textwrap::Options as WrapOptions;
use textwrap::wrap;

pub(crate) fn print_run_summary_box(
    color_enabled: bool,
    run_id: &str,
    run_path: &Path,
    deliverable_path: &Path,
    summary: Option<&str>,
    objective: Option<&str>,
    duration: Duration,
) {
    let mut items = Vec::new();
    items.push(("Run ID".to_string(), run_id.to_string()));
    items.push(("Run Directory".to_string(), run_path.display().to_string()));
    if let Some(objective) = objective
        && !objective.trim().is_empty()
    {
        items.push(("Objective".to_string(), objective.trim().to_string()));
    }
    items.push((
        "Deliverable".to_string(),
        deliverable_path.display().to_string(),
    ));
    items.push(("Total Time".to_string(), format_duration(duration)));
    if let Some(summary) = summary {
        let trimmed = summary.trim();
        if !trimmed.is_empty() {
            items.push(("Summary".to_string(), trimmed.to_string()));
        }
    }

    let label_width = items
        .iter()
        .map(|(label, _)| label.len())
        .max()
        .unwrap_or(0)
        .max(12);

    const DEFAULT_MAX_WIDTH: usize = 84;
    const MIN_VALUE_WIDTH: usize = 20;
    let label_padding = label_width + 7;
    let min_total_width = label_padding + MIN_VALUE_WIDTH;
    let available_width = terminal::size()
        .ok()
        .map(|(cols, _)| usize::from(cols).saturating_sub(2))
        .unwrap_or(DEFAULT_MAX_WIDTH);
    let max_width = available_width.min(DEFAULT_MAX_WIDTH);
    let lower_bound = min_total_width.min(available_width);
    let mut total_width = max_width.max(lower_bound).max(label_padding + 1);
    let mut value_width = total_width.saturating_sub(label_padding);
    if value_width < MIN_VALUE_WIDTH {
        value_width = MIN_VALUE_WIDTH;
        total_width = label_padding + value_width;
    }

    let inner_width = total_width.saturating_sub(4);
    let top_border = format!("+{}+", "=".repeat(total_width.saturating_sub(2)));
    let separator = format!("+{}+", "-".repeat(total_width.saturating_sub(2)));
    let title_line = format!(
        "| {:^inner_width$} |",
        "Run Summary",
        inner_width = inner_width
    );

    println!();
    println!("{top_border}");
    if color_enabled {
        println!("{}", title_line.bold());
    } else {
        println!("{title_line}");
    }
    println!("{separator}");

    for (index, (label, value)) in items.iter().enumerate() {
        let mut rows = Vec::new();
        for (idx, paragraph) in value.split('\n').enumerate() {
            let trimmed = paragraph.trim();
            if trimmed.is_empty() {
                if idx > 0 {
                    rows.push(String::new());
                }
                continue;
            }
            let wrapped = wrap(trimmed, WrapOptions::new(value_width).break_words(false));
            if wrapped.is_empty() {
                rows.push(String::new());
            } else {
                rows.extend(wrapped.into_iter().map(std::borrow::Cow::into_owned));
            }
        }
        if rows.is_empty() {
            rows.push(String::new());
        }

        for (line_idx, line) in rows.iter().enumerate() {
            let label_cell = if line_idx == 0 { label.as_str() } else { "" };
            let row_line = format!("| {label_cell:<label_width$} | {line:<value_width$} |");
            if color_enabled {
                match label.as_str() {
                    "Deliverable" => println!("{}", row_line.green()),
                    "Summary" => println!("{}", row_line.bold()),
                    _ => println!("{row_line}"),
                }
            } else {
                println!("{row_line}");
            }
        }

        if index + 1 < items.len() {
            println!("{separator}");
        }
    }

    println!("{top_border}");
    println!();
}
```
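
The width arithmetic above works because each row renders as `| <label> | <value> |`: the two border characters, the padding around the label, and the inter-cell separator account for exactly 7 non-content columns, hence `label_padding = label_width + 7`. A worked check of the numbers, assuming a 100-column terminal and the minimum 12-character label column (a simplified sketch of the same computation, not the full function):

```rust
fn main() {
    let label_width: usize = 12; // minimum label column
    let label_padding = label_width + 7; // "| " + " | " + " |" = 7 non-value chars
    let available_width = 100usize.saturating_sub(2); // terminal cols minus margin
    let max_width = available_width.min(84); // DEFAULT_MAX_WIDTH caps the box
    let total_width = max_width.max(label_padding + 20); // MIN_VALUE_WIDTH = 20
    let value_width = total_width - label_padding;
    assert_eq!((total_width, value_width), (84, 65));
    println!("box is {total_width} cols; values wrap at {value_width}");
}
```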
```
@@ -19,12 +19,15 @@ use codex_exec::Cli as ExecCli;
use codex_responses_api_proxy::Args as ResponsesApiProxyArgs;
use codex_tui::AppExitInfo;
use codex_tui::Cli as TuiCli;
use codex_tui::UpdateAction;
use owo_colors::OwoColorize;
use std::path::PathBuf;
use supports_color::Stream;

mod infty;
mod mcp_cmd;

use crate::infty::InftyCli;
use crate::mcp_cmd::McpCli;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
@@ -105,6 +108,10 @@ enum Subcommand {

    /// Inspect feature flags.
    Features(FeaturesCli),

    /// [experimental] Manage Codex Infty long-running task runs.
    #[clap(name = "infty")]
    Infty(InftyCli),
}

#[derive(Debug, Parser)]
@@ -208,6 +215,7 @@ fn format_exit_messages(exit_info: AppExitInfo, color_enabled: bool) -> Vec<String> {
    let AppExitInfo {
        token_usage,
        conversation_id,
        ..
    } = exit_info;

    if token_usage.is_zero() {
@@ -232,11 +240,32 @@ fn format_exit_messages(exit_info: AppExitInfo, color_enabled: bool) -> Vec<String> {
    lines
}

fn print_exit_messages(exit_info: AppExitInfo) {
/// Handle the app exit and print the results. Optionally run the update action.
fn handle_app_exit(exit_info: AppExitInfo) -> anyhow::Result<()> {
    let update_action = exit_info.update_action;
    let color_enabled = supports_color::on(Stream::Stdout).is_some();
    for line in format_exit_messages(exit_info, color_enabled) {
        println!("{line}");
    }
    if let Some(action) = update_action {
        run_update_action(action)?;
    }
    Ok(())
}

/// Run the update action and print the result.
fn run_update_action(action: UpdateAction) -> anyhow::Result<()> {
    println!();
    let (cmd, args) = action.command_args();
    let cmd_str = action.command_str();
    println!("Updating Codex via `{cmd_str}`...");
    let status = std::process::Command::new(cmd).args(args).status()?;
    if !status.success() {
        anyhow::bail!("`{cmd_str}` failed with status {status}");
    }
    println!();
    println!("🎉 Update ran successfully! Please restart Codex.");
    Ok(())
}

#[derive(Debug, Default, Parser, Clone)]
@@ -321,7 +350,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
            root_config_overrides.clone(),
        );
        let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
        print_exit_messages(exit_info);
        handle_app_exit(exit_info)?;
    }
    Some(Subcommand::Exec(mut exec_cli)) => {
        prepend_config_flags(
@@ -354,7 +383,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
            config_overrides,
        );
        let exit_info = codex_tui::run_main(interactive, codex_linux_sandbox_exe).await?;
        print_exit_messages(exit_info);
        handle_app_exit(exit_info)?;
    }
    Some(Subcommand::Login(mut login_cli)) => {
        prepend_config_flags(
@@ -404,6 +433,13 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
        );
        codex_cloud_tasks::run_main(cloud_cli, codex_linux_sandbox_exe).await?;
    }
    Some(Subcommand::Infty(mut infty_cli)) => {
        prepend_config_flags(
            &mut infty_cli.config_overrides,
            root_config_overrides.clone(),
        );
        infty_cli.run().await?;
    }
    Some(Subcommand::Sandbox(sandbox_args)) => match sandbox_args.cmd {
        SandboxCommand::Macos(mut seatbelt_cli) => {
            prepend_config_flags(
@@ -595,6 +631,7 @@ mod tests {
        conversation_id: conversation
            .map(ConversationId::from_string)
            .map(Result::unwrap),
        update_action: None,
    }
}

@@ -603,6 +640,7 @@ mod tests {
    let exit_info = AppExitInfo {
        token_usage: TokenUsage::default(),
        conversation_id: None,
        update_action: None,
    };
    let lines = format_exit_messages(exit_info, false);
    assert!(lines.is_empty());
```
```
@@ -6,6 +6,7 @@ use anyhow::anyhow;
use anyhow::bail;
use clap::ArgGroup;
use codex_common::CliConfigOverrides;
use codex_common::format_env_display::format_env_display;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::find_codex_home;
@@ -227,6 +228,8 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
            command: command_bin,
            args: command_args,
            env: env_map,
            env_vars: Vec::new(),
            cwd: None,
        }
    }
    AddMcpTransportArgs {
@@ -239,6 +242,8 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
    } => McpServerTransportConfig::StreamableHttp {
        url,
        bearer_token_env_var,
        http_headers: None,
        env_http_headers: None,
    },
    AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"),
};
@@ -260,11 +265,20 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
    if let McpServerTransportConfig::StreamableHttp {
        url,
        bearer_token_env_var: None,
        http_headers,
        env_http_headers,
    } = transport
        && matches!(supports_oauth_login(&url).await, Ok(true))
    {
        println!("Detected OAuth support. Starting OAuth flow…");
        perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?;
        perform_oauth_login(
            &name,
            &url,
            config.mcp_oauth_credentials_store_mode,
            http_headers.clone(),
            env_http_headers.clone(),
        )
        .await?;
        println!("Successfully logged in.");
    }

@@ -317,12 +331,24 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) -> Result<()> {
        bail!("No MCP server named '{name}' found.");
    };

    let url = match &server.transport {
        McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(),
    let (url, http_headers, env_http_headers) = match &server.transport {
        McpServerTransportConfig::StreamableHttp {
            url,
            http_headers,
            env_http_headers,
            ..
        } => (url.clone(), http_headers.clone(), env_http_headers.clone()),
        _ => bail!("OAuth login is only supported for streamable HTTP servers."),
    };

    perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?;
    perform_oauth_login(
        &name,
        &url,
        config.mcp_oauth_credentials_store_mode,
        http_headers,
        env_http_headers,
    )
    .await?;
    println!("Successfully logged in to MCP server '{name}'.");
    Ok(())
}
@@ -377,20 +403,32 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
        .copied()
        .unwrap_or(McpAuthStatus::Unsupported);
    let transport = match &cfg.transport {
        McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({
        McpServerTransportConfig::Stdio {
            command,
            args,
            env,
            env_vars,
            cwd,
        } => serde_json::json!({
            "type": "stdio",
            "command": command,
            "args": args,
            "env": env,
            "env_vars": env_vars,
            "cwd": cwd,
        }),
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            serde_json::json!({
                "type": "streamable_http",
                "url": url,
                "bearer_token_env_var": bearer_token_env_var,
                "http_headers": http_headers,
                "env_http_headers": env_http_headers,
            })
        }
    };
@@ -419,30 +457,29 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
        return Ok(());
    }

    let mut stdio_rows: Vec<[String; 6]> = Vec::new();
    let mut stdio_rows: Vec<[String; 7]> = Vec::new();
    let mut http_rows: Vec<[String; 5]> = Vec::new();

    for (name, cfg) in entries {
        match &cfg.transport {
            McpServerTransportConfig::Stdio { command, args, env } => {
            McpServerTransportConfig::Stdio {
                command,
                args,
                env,
                env_vars,
                cwd,
            } => {
                let args_display = if args.is_empty() {
                    "-".to_string()
                } else {
                    args.join(" ")
                };
                let env_display = match env.as_ref() {
                    None => "-".to_string(),
                    Some(map) if map.is_empty() => "-".to_string(),
                    Some(map) => {
                        let mut pairs: Vec<_> = map.iter().collect();
                        pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
                        pairs
                            .into_iter()
                            .map(|(k, v)| format!("{k}={v}"))
                            .collect::<Vec<_>>()
                            .join(", ")
                    }
                };
                let env_display = format_env_display(env.as_ref(), env_vars);
                let cwd_display = cwd
                    .as_ref()
                    .map(|path| path.display().to_string())
                    .filter(|value| !value.is_empty())
                    .unwrap_or_else(|| "-".to_string());
                let status = if cfg.enabled {
                    "enabled".to_string()
                } else {
@@ -458,6 +495,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
                    command.clone(),
                    args_display,
                    env_display,
                    cwd_display,
                    status,
                    auth_status,
                ]);
@@ -465,6 +503,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
            McpServerTransportConfig::StreamableHttp {
                url,
                bearer_token_env_var,
                ..
            } => {
                let status = if cfg.enabled {
                    "enabled".to_string()
@@ -493,6 +532,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
        "Command".len(),
        "Args".len(),
        "Env".len(),
        "Cwd".len(),
        "Status".len(),
        "Auth".len(),
    ];
@@ -503,36 +543,40 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
    }

    println!(
        "{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {status:<status_w$} {auth:<auth_w$}",
        "{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {cwd:<cwd_w$} {status:<status_w$} {auth:<auth_w$}",
        name = "Name",
        command = "Command",
        args = "Args",
        env = "Env",
        cwd = "Cwd",
        status = "Status",
        auth = "Auth",
        name_w = widths[0],
        cmd_w = widths[1],
        args_w = widths[2],
        env_w = widths[3],
        status_w = widths[4],
        auth_w = widths[5],
        cwd_w = widths[4],
        status_w = widths[5],
        auth_w = widths[6],
    );

    for row in &stdio_rows {
        println!(
            "{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {status:<status_w$} {auth:<auth_w$}",
```
|
||||
"{name:<name_w$} {command:<cmd_w$} {args:<args_w$} {env:<env_w$} {cwd:<cwd_w$} {status:<status_w$} {auth:<auth_w$}",
|
||||
name = row[0].as_str(),
|
||||
command = row[1].as_str(),
|
||||
args = row[2].as_str(),
|
||||
env = row[3].as_str(),
|
||||
status = row[4].as_str(),
|
||||
auth = row[5].as_str(),
|
||||
cwd = row[4].as_str(),
|
||||
status = row[5].as_str(),
|
||||
auth = row[6].as_str(),
|
||||
name_w = widths[0],
|
||||
cmd_w = widths[1],
|
||||
args_w = widths[2],
|
||||
env_w = widths[3],
|
||||
status_w = widths[4],
|
||||
auth_w = widths[5],
|
||||
cwd_w = widths[4],
|
||||
status_w = widths[5],
|
||||
auth_w = widths[6],
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -601,19 +645,31 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
|
||||
if get_args.json {
|
||||
let transport = match &server.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => serde_json::json!({
|
||||
"type": "stdio",
|
||||
"command": command,
|
||||
"args": args,
|
||||
"env": env,
|
||||
"env_vars": env_vars,
|
||||
"cwd": cwd,
|
||||
}),
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => serde_json::json!({
|
||||
"type": "streamable_http",
|
||||
"url": url,
|
||||
"bearer_token_env_var": bearer_token_env_var,
|
||||
"http_headers": http_headers,
|
||||
"env_http_headers": env_http_headers,
|
||||
}),
|
||||
};
|
||||
let output = serde_json::to_string_pretty(&serde_json::json!({
|
||||
@@ -634,7 +690,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
println!("{}", get_args.name);
|
||||
println!(" enabled: {}", server.enabled);
|
||||
match &server.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
env_vars,
|
||||
cwd,
|
||||
} => {
|
||||
println!(" transport: stdio");
|
||||
println!(" command: {command}");
|
||||
let args_display = if args.is_empty() {
|
||||
@@ -643,10 +705,27 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
args.join(" ")
|
||||
};
|
||||
println!(" args: {args_display}");
|
||||
let env_display = match env.as_ref() {
|
||||
None => "-".to_string(),
|
||||
Some(map) if map.is_empty() => "-".to_string(),
|
||||
Some(map) => {
|
||||
let cwd_display = cwd
|
||||
.as_ref()
|
||||
.map(|path| path.display().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
println!(" cwd: {cwd_display}");
|
||||
let env_display = format_env_display(env.as_ref(), env_vars);
|
||||
println!(" env: {env_display}");
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
println!(" transport: streamable_http");
|
||||
println!(" url: {url}");
|
||||
let env_var = bearer_token_env_var.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token_env_var: {env_var}");
|
||||
let headers_display = match http_headers {
|
||||
Some(map) if !map.is_empty() => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
@@ -655,17 +734,22 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
_ => "-".to_string(),
|
||||
};
|
||||
println!(" env: {env_display}");
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
println!(" transport: streamable_http");
|
||||
println!(" url: {url}");
|
||||
let env_var = bearer_token_env_var.as_deref().unwrap_or("-");
|
||||
println!(" bearer_token_env_var: {env_var}");
|
||||
println!(" http_headers: {headers_display}");
|
||||
let env_headers_display = match env_http_headers {
|
||||
Some(map) if !map.is_empty() => {
|
||||
let mut pairs: Vec<_> = map.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
pairs
|
||||
.into_iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
_ => "-".to_string(),
|
||||
};
|
||||
println!(" env_http_headers: {env_headers_display}");
|
||||
}
|
||||
}
|
||||
if let Some(timeout) = server.startup_timeout_sec {
|
||||
|
||||
@@ -28,10 +28,18 @@ async fn add_and_remove_server_updates_global_config() -> Result<()> {
    assert_eq!(servers.len(), 1);
    let docs = servers.get("docs").expect("server should exist");
    match &docs.transport {
        McpServerTransportConfig::Stdio { command, args, env } => {
        McpServerTransportConfig::Stdio {
            command,
            args,
            env,
            env_vars,
            cwd,
        } => {
            assert_eq!(command, "echo");
            assert_eq!(args, &vec!["hello".to_string()]);
            assert!(env.is_none());
            assert!(env_vars.is_empty());
            assert!(cwd.is_none());
        }
        other => panic!("unexpected transport: {other:?}"),
    }

@@ -112,9 +120,13 @@ async fn add_streamable_http_without_manual_token() -> Result<()> {
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            assert_eq!(url, "https://example.com/mcp");
            assert!(bearer_token_env_var.is_none());
            assert!(http_headers.is_none());
            assert!(env_http_headers.is_none());
        }
        other => panic!("unexpected transport: {other:?}"),
    }

@@ -150,9 +162,13 @@ async fn add_streamable_http_with_custom_env_var() -> Result<()> {
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            assert_eq!(url, "https://example.com/issues");
            assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN"));
            assert!(http_headers.is_none());
            assert!(env_http_headers.is_none());
        }
        other => panic!("unexpected transport: {other:?}"),
    }

@@ -1,6 +1,9 @@
use std::path::Path;

use anyhow::Result;
use codex_core::config::load_global_mcp_servers;
use codex_core::config::write_global_mcp_servers;
use codex_core::config_types::McpServerTransportConfig;
use predicates::prelude::PredicateBooleanExt;
use predicates::str::contains;
use pretty_assertions::assert_eq;

@@ -27,8 +30,8 @@ fn list_shows_empty_state() -> Result<()> {
    Ok(())
}

#[test]
fn list_and_get_render_expected_output() -> Result<()> {
#[tokio::test]
async fn list_and_get_render_expected_output() -> Result<()> {
    let codex_home = TempDir::new()?;

    let mut add = codex_command(codex_home.path())?;

@@ -46,6 +49,18 @@ fn list_and_get_render_expected_output() -> Result<()> {
        .assert()
        .success();

    let mut servers = load_global_mcp_servers(codex_home.path()).await?;
    let docs_entry = servers
        .get_mut("docs")
        .expect("docs server should exist after add");
    match &mut docs_entry.transport {
        McpServerTransportConfig::Stdio { env_vars, .. } => {
            *env_vars = vec!["APP_TOKEN".to_string(), "WORKSPACE_ID".to_string()];
        }
        other => panic!("unexpected transport: {other:?}"),
    }
    write_global_mcp_servers(codex_home.path(), &servers)?;

    let mut list_cmd = codex_command(codex_home.path())?;
    let list_output = list_cmd.args(["mcp", "list"]).output()?;
    assert!(list_output.status.success());

@@ -54,6 +69,8 @@ fn list_and_get_render_expected_output() -> Result<()> {
    assert!(stdout.contains("docs"));
    assert!(stdout.contains("docs-server"));
    assert!(stdout.contains("TOKEN=secret"));
    assert!(stdout.contains("APP_TOKEN=$APP_TOKEN"));
    assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID"));
    assert!(stdout.contains("Status"));
    assert!(stdout.contains("Auth"));
    assert!(stdout.contains("enabled"));

@@ -79,7 +96,12 @@ fn list_and_get_render_expected_output() -> Result<()> {
            ],
            "env": {
                "TOKEN": "secret"
            }
            },
            "env_vars": [
                "APP_TOKEN",
                "WORKSPACE_ID"
            ],
            "cwd": null
        },
        "startup_timeout_sec": null,
        "tool_timeout_sec": null,

@@ -98,6 +120,8 @@ fn list_and_get_render_expected_output() -> Result<()> {
    assert!(stdout.contains("command: docs-server"));
    assert!(stdout.contains("args: --port 4000"));
    assert!(stdout.contains("env: TOKEN=secret"));
    assert!(stdout.contains("APP_TOKEN=$APP_TOKEN"));
    assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID"));
    assert!(stdout.contains("enabled: true"));
    assert!(stdout.contains("remove: codex mcp remove docs"));
codex-rs/codex-infty/Cargo.toml (new file)
@@ -0,0 +1,24 @@
[package]
name = "codex-infty"
version = { workspace = true }
edition = "2024"

[dependencies]
anyhow = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
codex-core = { path = "../core" }
codex-protocol = { path = "../protocol" }
dirs = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "signal"] }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true, features = ["log"] }
futures = "0.3"

[dev-dependencies]
core_test_support = { path = "../core/tests/common" }
tempfile = { workspace = true }
wiremock = { workspace = true }
codex-rs/codex-infty/README.md (new file)
@@ -0,0 +1,196 @@
# Codex Infty

Codex Infty is a small orchestration layer that coordinates multiple Codex roles (Solver, Director, Verifier(s)) to drive longer, multi‑step objectives with minimal human intervention. It provides:

- A run orchestrator that routes messages between roles and advances the workflow.
- A durable run store on disk with metadata and standard subfolders.
- Default role prompts for Solver/Director/Verifier.
- A lightweight progress reporting hook for UIs/CLIs.

The crate is designed to be embedded (via the library API) and also powers the `codex infty` CLI commands.

## High‑Level Flow

```
objective → Solver
Solver → direction_request → Director → directive → Solver
… (iterate) …
Solver → final_delivery → Orchestrator returns RunOutcome
```

- The Solver always speaks structured JSON. The orchestrator parses those messages and decides the next hop.
- The Director provides crisp guidance (also JSON) that is forwarded back to the Solver.
- One or more Verifiers may assess the final deliverable; the orchestrator aggregates results and reports a summary to the Solver.
- On final_delivery, the orchestrator resolves and validates the deliverable path and returns the `RunOutcome`.

## Directory Layout (Run Store)

When a run is created, a directory is initialized with this structure:

```
<runs_root>/<run_id>/
  artifacts/    # long‑lived artifacts produced by the Solver
  memory/       # durable notes, claims, context
  index/        # indexes and caches
  deliverable/  # final output(s) assembled by the Solver
  run.json      # run metadata (id, timestamps, roles)
```

See: `codex-infty/src/run_store.rs`.

- The orchestrator persists rollout paths and optional config paths for each role into `run.json`.
- Metadata timestamps are updated on significant events (role spawns, handoffs, final delivery).
- Final deliverables must remain within the run directory. Paths are canonicalized and validated, as sketched below.
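A minimal sketch of that containment check (the crate's own helper lives in `utils::resolve_deliverable_path`; this is an illustration of the idea, not its exact implementation):

```rust
use std::path::{Path, PathBuf};

use anyhow::{Result, ensure};

/// Resolve `candidate` relative to the run directory and reject any path
/// that escapes it after canonicalization (e.g. via `..` components).
fn resolve_within_run(run_root: &Path, candidate: &str) -> Result<PathBuf> {
    let root = run_root.canonicalize()?;
    let resolved = root.join(candidate).canonicalize()?;
    ensure!(
        resolved.starts_with(&root),
        "deliverable path escapes the run directory"
    );
    Ok(resolved)
}
```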
## Roles and Prompts

Default base instructions are injected per role if the provided `Config` has none:

- Solver: `codex-infty/src/prompts/solver.md`
- Director: `codex-infty/src/prompts/director.md`
- Verifier: `codex-infty/src/prompts/verifier.md`

You can provide your own instructions by pre‑populating `Config.base_instructions`, for example:
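A minimal sketch, reusing the assumed `load_config()` helper from the Library Usage example below:

```rust
use codex_infty::RoleConfig;

// Assumed helper; see the Library Usage example below.
let mut solver_cfg = load_config();
// Pre-populated instructions take precedence over the bundled solver.md prompt.
solver_cfg.base_instructions = Some("You are the Solver. …".to_string());
let solver = RoleConfig::new("solver", solver_cfg);
```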
## Solver Signal Contract

The Solver communicates intent using JSON messages (possibly wrapped in a fenced block). The orchestrator accepts two shapes:

- Direction request (sent to Director):

```json
{"type":"direction_request","prompt":"<question or decision>"}
```

- Final delivery (completes the run):

```json
{"type":"final_delivery","deliverable_path":"deliverable/summary.txt","summary":"<short text>"}
```

JSON may be fenced as ```json … ```; the orchestrator will strip the fence.
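For consumers that want to mirror this parsing outside the orchestrator, here is a minimal sketch. The field names follow the contract above; the crate's own parser lives in `roles/solver.rs` and may differ in detail:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum Signal {
    DirectionRequest { prompt: String },
    FinalDelivery { deliverable_path: String, summary: Option<String> },
}

fn parse_signal(raw: &str) -> Option<Signal> {
    let trimmed = raw.trim();
    // Accept either bare JSON or a json-fenced block.
    let body = trimmed
        .strip_prefix("```json")
        .and_then(|rest| rest.strip_suffix("```"))
        .map(str::trim)
        .unwrap_or(trimmed);
    serde_json::from_str(body).ok()
}
```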
## Key Types and Modules

- Orchestrator: `codex-infty/src/orchestrator.rs`
  - `InftyOrchestrator`: spawns/resumes role sessions, drives the event loop, and routes signals.
  - `execute_new_run`: one‑shot helper that spawns and then drives.
  - `spawn_run`: set up sessions and the run store.
  - `call_role`, `relay_assistant_to_role`, `post_to_role`, `await_first_assistant`, `stream_events`: utilities when integrating custom flows.

- Run store: `codex-infty/src/run_store.rs`
  - `RunStore`, `RunMetadata`, `RoleMetadata`: metadata and persistence helpers.

- Types: `codex-infty/src/types.rs`
  - `RoleConfig`: wraps a `Config` and sets sensible defaults for autonomous flows (no approvals, full sandbox access). Also used to persist optional config paths.
  - `RunParams`: input to spawn runs.
  - `RunExecutionOptions`: per‑run options (objective, timeouts).
  - `RunOutcome`: returned on successful final delivery.

- Signals: `codex-infty/src/signals.rs`
  - DTOs for director responses and verifier verdicts, and the aggregated summary type.

- Progress: `codex-infty/src/progress.rs`
  - `ProgressReporter` trait: hook for UIs/CLIs to observe solver/director/verifier activity.

## Orchestrator Workflow (Details)

1. Spawn or resume role sessions (Solver, Director, and zero or more Verifiers). Default prompts are applied if the role’s `Config` has no base instructions.
2. Optionally post an `objective` to the Solver. The progress reporter is notified and the orchestrator waits for the first Solver signal.
3. On `direction_request`:
   - Post a structured request to the Director and await the first assistant message.
   - Parse it into a `DirectiveResponse` and forward the normalized JSON to the Solver.
4. On `final_delivery`:
   - Canonicalize and validate that `deliverable_path` stays within the run directory.
   - Optionally run a verification pass using configured Verifier(s), aggregate results, and post a summary back to the Solver.
   - Notify the progress reporter, touch the run store, and return `RunOutcome`.

## Library Usage

```rust
use std::sync::Arc;
use codex_core::{CodexAuth, config::Config};
use codex_infty::{InftyOrchestrator, RoleConfig, RunParams, RunExecutionOptions};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // 1) Load or build a Config for each role
    let solver_cfg: Config = load_config();
    let mut director_cfg = solver_cfg.clone();
    director_cfg.model = "o4-mini".into();

    // 2) Build role configs
    let solver = RoleConfig::new("solver", solver_cfg.clone());
    let director = RoleConfig::new("director", director_cfg);
    let verifiers = vec![RoleConfig::new("verifier-alpha", solver_cfg.clone())];

    // 3) Create an orchestrator (using default runs root)
    let auth = CodexAuth::from_api_key("sk-…");
    let orchestrator = InftyOrchestrator::new(auth)?;

    // 4) Execute a new run with an objective
    let params = RunParams {
        run_id: "my-run".into(),
        run_root: None, // use default ~/.codex/infty/<run_id>
        solver,
        director,
        verifiers,
    };
    let mut opts = RunExecutionOptions::default();
    opts.objective = Some("Implement feature X".into());

    let outcome = orchestrator.execute_new_run(params, opts).await?;
    println!("deliverable: {}", outcome.deliverable_path.display());
    Ok(())
}
# fn load_config() -> codex_core::config::Config { codex_core::config::Config::default() }
```

Note: Resuming runs is currently disabled.

## CLI Quickstart

The CLI (`codex`) exposes Infty helpers under the `infty` subcommand. Examples:

```bash
# Create a run and immediately drive toward completion
codex infty create --run-id demo --objective "Build and test feature"

# Inspect runs
codex infty list
codex infty show demo

# Sending one-off messages to stored runs is currently disabled
```

Flags allow customizing the Director’s model and reasoning effort; see `codex infty create --help`.

## Progress Reporting

Integrate your UI by implementing `ProgressReporter` and attaching it with `InftyOrchestrator::with_progress(...)`. You’ll receive callbacks on key milestones (objective posted, solver messages, director response, verification summaries, final delivery, etc.). For example:
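A minimal sketch: the trait's methods all have default no-op bodies, so a reporter only overrides the callbacks it cares about (`LogReporter` is a hypothetical name):

```rust
use std::path::Path;

use codex_infty::ProgressReporter;

struct LogReporter;

impl ProgressReporter for LogReporter {
    fn objective_posted(&self, objective: &str) {
        eprintln!("objective posted: {objective}");
    }

    fn final_delivery(&self, deliverable_path: &Path, summary: Option<&str>) {
        eprintln!("final delivery: {}", deliverable_path.display());
        if let Some(summary) = summary {
            eprintln!("summary: {summary}");
        }
    }
}

// Attach it during construction, e.g.:
// let orchestrator = InftyOrchestrator::new(auth)?.with_progress(std::sync::Arc::new(LogReporter));
```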
## Safety and Guardrails

- `RoleConfig::new` sets `SandboxPolicy::DangerFullAccess` and `AskForApproval::Never` to support autonomous flows. Adjust if your environment requires stricter policies.
- Deliverable paths are validated to stay inside the run directory and are fully canonicalized.
- JSON payloads are schema‑checked where applicable (e.g., solver signals and final delivery shape).

## Tests

Run the crate’s tests:

```bash
cargo test -p codex-infty
```

Many tests rely on mocked SSE streams and will auto‑skip in sandboxes where network is disabled.

## When to Use This Crate

Use `codex-infty` when you want a minimal, pragmatic multi‑role loop with:

- Clear role separation and routing.
- Durable, restart‑resilient state on disk.
- Simple integration points (progress hooks and helper APIs).

It’s intentionally small and focused so it can be embedded into larger tools or extended to meet your workflows.
codex-rs/codex-infty/src/lib.rs (new file)
@@ -0,0 +1,38 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]

mod orchestrator;
mod progress;
mod prompts;
mod roles;
mod run_store;
mod session;
mod signals;
mod types;
pub(crate) mod utils;

pub use orchestrator::InftyOrchestrator;
pub use progress::ProgressReporter;
pub use run_store::RoleMetadata;
pub use run_store::RunMetadata;
pub use run_store::RunStore;
pub use signals::AggregatedVerifierVerdict;
pub use signals::DirectiveResponse;
pub use signals::VerifierDecision;
pub use signals::VerifierReport;
pub use signals::VerifierVerdict;
pub use types::RoleConfig;
pub use types::RoleSession;
pub use types::RunExecutionOptions;
pub use types::RunOutcome;
pub use types::RunParams;
pub use types::RunSessions;

use anyhow::Result;
use anyhow::anyhow;
use dirs::home_dir;
use std::path::PathBuf;

pub fn default_runs_root() -> Result<PathBuf> {
    let home = home_dir().ok_or_else(|| anyhow!("failed to determine home directory"))?;
    Ok(home.join(".codex").join("infty"))
}
codex-rs/codex-infty/src/orchestrator.rs (new file)
@@ -0,0 +1,552 @@
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use codex_core::CodexAuth;
use codex_core::CodexConversation;
use codex_core::ConversationManager;
use codex_core::cross_session::CrossSessionHub;
use codex_core::protocol::EventMsg;
use codex_core::protocol::Op;
use codex_protocol::ConversationId;
use tokio::signal;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use tracing::warn;

use crate::progress::ProgressReporter;
use crate::prompts;
use crate::roles::Role;
use crate::roles::director::DirectionRequestPayload;
use crate::roles::director::DirectorRole;
use crate::roles::solver::SolverRequest;
use crate::roles::solver::SolverRole;
use crate::roles::solver::SolverSignal;
use crate::roles::solver::parse_solver_signal;
use crate::roles::verifier::VerificationRequestPayload;
use crate::roles::verifier_pool::VerifierPool;
use crate::run_store::RoleMetadata;
use crate::run_store::RunStore;
use crate::session;
use crate::signals::AggregatedVerifierVerdict;
use crate::types::RoleConfig;
use crate::types::RoleSession;
use crate::types::RunExecutionOptions;
use crate::types::RunOutcome;
use crate::types::RunParams;
use crate::types::RunSessions;

#[derive(Default)]
struct LoopState {
    waiting_for_signal: bool,
    pending_solver_turn_completion: bool,
}

struct SessionCleanup {
    conversation_id: ConversationId,
    conversation: Arc<CodexConversation>,
}

impl SessionCleanup {
    fn new(session: &RoleSession) -> Self {
        Self {
            conversation_id: session.conversation_id,
            conversation: Arc::clone(&session.conversation),
        }
    }
}

pub struct InftyOrchestrator {
    hub: Arc<CrossSessionHub>,
    conversation_manager: ConversationManager,
    runs_root: PathBuf,
    progress: Option<Arc<dyn ProgressReporter>>,
}

impl InftyOrchestrator {
    fn progress_ref(&self) -> Option<&dyn ProgressReporter> {
        self.progress.as_deref()
    }

    pub fn new(auth: CodexAuth) -> Result<Self> {
        let runs_root = crate::default_runs_root()?;
        Ok(Self::with_runs_root(auth, runs_root))
    }

    pub fn with_runs_root(auth: CodexAuth, runs_root: impl Into<PathBuf>) -> Self {
        Self {
            hub: Arc::new(CrossSessionHub::new()),
            conversation_manager: ConversationManager::with_auth(auth),
            runs_root: runs_root.into(),
            progress: None,
        }
    }

    pub fn runs_root(&self) -> &PathBuf {
        &self.runs_root
    }

    pub fn hub(&self) -> Arc<CrossSessionHub> {
        Arc::clone(&self.hub)
    }

    pub fn with_progress(mut self, reporter: Arc<dyn ProgressReporter>) -> Self {
        self.progress = Some(reporter);
        self
    }

    pub async fn execute_new_run(
        &self,
        params: RunParams,
        options: RunExecutionOptions,
    ) -> Result<RunOutcome> {
        let sessions = self.spawn_run(params).await?;
        self.drive_run(sessions, options).await
    }

    // resumable runs are disabled; execute_existing_run removed

    pub async fn spawn_run(&self, params: RunParams) -> Result<RunSessions> {
        let RunParams {
            run_id,
            run_root,
            solver,
            director,
            verifiers,
        } = params;

        let run_path = run_root.unwrap_or_else(|| self.runs_root.join(&run_id));
        let role_metadata = collect_role_metadata(&solver, &director, &verifiers);
        let mut store = RunStore::initialize(&run_path, &run_id, &role_metadata)?;
        let mut cleanup = Vec::new();

        let solver_session = match self
            .spawn_and_register_role(&run_id, &run_path, &solver, &mut store, &mut cleanup)
            .await
        {
            Ok(session) => session,
            Err(err) => {
                self.cleanup_failed_spawn(cleanup, &run_path).await;
                return Err(err);
            }
        };

        let director_session = match self
            .spawn_and_register_role(&run_id, &run_path, &director, &mut store, &mut cleanup)
            .await
        {
            Ok(session) => session,
            Err(err) => {
                self.cleanup_failed_spawn(cleanup, &run_path).await;
                return Err(err);
            }
        };

        let mut verifier_sessions = Vec::with_capacity(verifiers.len());
        for verifier in verifiers {
            let session = match self
                .spawn_and_register_role(&run_id, &run_path, &verifier, &mut store, &mut cleanup)
                .await
            {
                Ok(session) => session,
                Err(err) => {
                    self.cleanup_failed_spawn(cleanup, &run_path).await;
                    return Err(err);
                }
            };
            verifier_sessions.push(session);
        }

        Ok(RunSessions {
            run_id,
            solver: solver_session,
            director: director_session,
            verifiers: verifier_sessions,
            store,
        })
    }

    // resumable runs are disabled; resume_run removed

    async fn drive_run(
        &self,
        mut sessions: RunSessions,
        options: RunExecutionOptions,
    ) -> Result<RunOutcome> {
        let result = self.inner_drive_run(&mut sessions, &options).await;
        let cleanup = collect_session_cleanup(&sessions);
        self.shutdown_sessions(cleanup).await;
        result
    }

    async fn inner_drive_run(
        &self,
        sessions: &mut RunSessions,
        options: &RunExecutionOptions,
    ) -> Result<RunOutcome> {
        let solver_role = SolverRole::new(
            Arc::clone(&self.hub),
            sessions.run_id.clone(),
            sessions.solver.role.clone(),
            sessions.solver.conversation_id,
            self.progress.clone(),
        );
        let director_role = DirectorRole::new(
            Arc::clone(&self.hub),
            sessions.run_id.clone(),
            sessions.director.role.clone(),
            options.director_timeout,
            self.progress.clone(),
        );
        let mut verifier_pool = VerifierPool::from_sessions(
            Arc::clone(&self.hub),
            sessions,
            options.verifier_timeout,
            self.progress.clone(),
        );

        let mut solver_events = solver_role.stream_events()?;
        let mut state = LoopState::default();
        self.maybe_post_objective(&solver_role, sessions, &mut state, options)
            .await?;

        // Cancellation token that propagates Ctrl+C to nested awaits
        let cancel = CancellationToken::new();
        let cancel_on_ctrl_c = cancel.clone();
        tokio::spawn(async move {
            let _ = signal::ctrl_c().await;
            cancel_on_ctrl_c.cancel();
        });

        'event_loop: loop {
            tokio::select! {
                maybe_event = solver_events.next() => {
                    let Some(event) = maybe_event else {
                        break 'event_loop;
                    };
                    if let Some(p) = self.progress_ref() { p.solver_event(&event.event.msg); }
                    match &event.event.msg {
                        EventMsg::AgentMessage(agent_msg) => {
                            if let Some(p) = self.progress_ref() { p.solver_agent_message(agent_msg); }
                            if let Some(signal) = parse_solver_signal(&agent_msg.message) {
                                state.waiting_for_signal = false;
                                match signal {
                                    SolverSignal::DirectionRequest { prompt } => {
                                        let prompt = crate::utils::required_trimmed(
                                            prompt,
                                            "solver direction_request missing prompt text",
                                        )?;
                                        if let Some(p) = self.progress_ref() { p.direction_request(&prompt); }
                                        self
                                            .handle_direction_request(
                                                &prompt,
                                                options,
                                                &director_role,
                                                &solver_role,
                                                cancel.clone(),
                                            )
                                            .await?;
                                        sessions.store.touch()?;
                                        state.pending_solver_turn_completion = true;
                                    }
                                    SolverSignal::FinalDelivery {
                                        deliverable_path,
                                        summary,
                                    } => {
                                        let deliverable_path = crate::utils::required_trimmed(
                                            deliverable_path,
                                            "solver final_delivery missing deliverable_path",
                                        )?;
                                        if deliverable_path.is_empty() { bail!("solver final_delivery provided empty path"); }

                                        // Minimal behavior: if the provided path cannot be resolved,
                                        // send a placeholder claim so verifiers can fail it.
                                        let resolved = crate::utils::resolve_deliverable_path(
                                            sessions.store.path(),
                                            &deliverable_path,
                                        )
                                        .unwrap_or_else(|_| std::path::PathBuf::from("file not existing"));

                                        let summary_clean = crate::utils::trim_to_non_empty(summary);
                                        let summary_ref = summary_clean.as_deref();
                                        if let Some(p) = self.progress_ref() { p.final_delivery(&resolved, summary_ref); }
                                        let verified = self
                                            .run_final_verification(
                                                sessions,
                                                &mut verifier_pool,
                                                &resolved,
                                                summary_ref,
                                                options,
                                                &solver_role,
                                                cancel.clone(),
                                            )
                                            .await?;
                                        if !verified { state.pending_solver_turn_completion = true; continue; }
                                        sessions.store.touch()?;
                                        return Ok(RunOutcome {
                                            run_id: sessions.run_id.clone(),
                                            deliverable_path: resolved,
                                            summary: summary_clean,
                                            raw_message: agent_msg.message.clone(),
                                        });
                                    }
                                }
                            }
                        }
                        EventMsg::TaskComplete(..) => {
                            if state.waiting_for_signal {
                                // The solver completed its turn without issuing a signal; ask for one now.
                                solver_role.request_finalization_signal().await?;
                            } else if state.pending_solver_turn_completion {
                                // We handled a signal earlier in the loop; this completion corresponds to it.
                                state.pending_solver_turn_completion = false;
                            }
                        }
                        EventMsg::Error(error) => {
                            tracing::error!("Error: {:?}", error);
                        }
                        EventMsg::StreamError(error) => {
                            tracing::error!("Stream error: {:?}", error);
                        }
                        e => {
                            tracing::info!("Unhandled event: {:?}", e); // todo move to trace
                        }
                    }
                }
                _ = cancel.cancelled() => {
                    if let Some(progress) = self.progress.as_ref() { progress.run_interrupted(); }
                    // Proactively interrupt any in-flight role turns for fast shutdown.
                    let _ = sessions.solver.conversation.submit(Op::Interrupt).await;
                    let _ = sessions.director.conversation.submit(Op::Interrupt).await;
                    for v in &sessions.verifiers { let _ = v.conversation.submit(Op::Interrupt).await; }
                    // Cleanup is handled by the caller (drive_run) to avoid double-shutdown
                    bail!("run interrupted by Ctrl+C");
                }
            }
        }

        Err(anyhow!(
            "run {} ended before emitting final_delivery message",
            sessions.run_id
        ))
    }

    async fn maybe_post_objective(
        &self,
        solver: &crate::roles::solver::SolverRole,
        sessions: &mut RunSessions,
        state: &mut LoopState,
        options: &RunExecutionOptions,
    ) -> Result<()> {
        if let Some(objective) = options.objective.as_deref()
            && !objective.trim().is_empty()
        {
            solver
                .post(objective, Some(SolverRole::solver_signal_schema()))
                .await?;
            sessions.store.touch()?;
            state.waiting_for_signal = true;
            if let Some(p) = self.progress_ref() {
                p.objective_posted(objective);
            }
        }
        Ok(())
    }

    async fn handle_direction_request(
        &self,
        prompt: &str,
        options: &RunExecutionOptions,
        director_role: &DirectorRole,
        solver_role: &SolverRole,
        cancel: CancellationToken,
    ) -> Result<()> {
        let request = DirectionRequestPayload::new(prompt, options.objective.as_deref());
        let directive_payload = tokio::select! {
            r = director_role.call(&request) => {
                r.context("director response was not valid directive JSON")?
            }
            _ = cancel.cancelled() => {
                bail!("interrupted")
            }
        };
        if let Some(progress) = self.progress.as_ref() {
            progress.director_response(&directive_payload);
        }
        let req = SolverRequest::from(directive_payload);
        tokio::select! {
            r = solver_role.call(&req) => { r?; }
            _ = cancel.cancelled() => { bail!("interrupted"); }
        }
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    async fn run_final_verification(
        &self,
        sessions: &mut RunSessions,
        verifier_pool: &mut VerifierPool,
        deliverable_path: &Path,
        summary: Option<&str>,
        options: &RunExecutionOptions,
        solver_role: &SolverRole,
        cancel: CancellationToken,
    ) -> Result<bool> {
        let relative = deliverable_path
            .strip_prefix(sessions.store.path())
            .ok()
            .and_then(|p| p.to_str().map(|s| s.to_string()));
        let claim_path = relative.unwrap_or_else(|| deliverable_path.display().to_string());

        let objective = crate::utils::objective_as_str(options);

        let request = VerificationRequestPayload::new(claim_path.as_str(), summary, objective);
        if verifier_pool.is_empty() {
            return Ok(true);
        }
        let round = tokio::select! {
            r = verifier_pool.collect_round(&request) => { r? }
            _ = cancel.cancelled() => { bail!("interrupted"); }
        };
        verifier_pool
            .rotate_passing(sessions, &self.conversation_manager, &round.passing_roles)
            .await?;
        let summary_result = round.summary;
        self.emit_verification_summary(&summary_result);
        let req = SolverRequest::from(&summary_result);
        tokio::select! {
            r = solver_role.call(&req) => { r?; }
            _ = cancel.cancelled() => { bail!("interrupted"); }
        }
        Ok(summary_result.overall.is_pass())
    }

    fn emit_verification_summary(&self, summary: &AggregatedVerifierVerdict) {
        if let Some(progress) = self.progress.as_ref() {
            progress.verification_summary(summary);
        }
    }

    async fn cleanup_failed_spawn(&self, sessions: Vec<SessionCleanup>, run_path: &Path) {
        self.shutdown_sessions(sessions).await;
        if run_path.exists()
            && let Err(err) = fs::remove_dir_all(run_path)
        {
            warn!(
                path = %run_path.display(),
                ?err,
                "failed to remove run directory after spawn failure"
            );
        }
    }

    // resumable runs are disabled; cleanup_failed_resume removed

    async fn shutdown_sessions(&self, sessions: Vec<SessionCleanup>) {
        for session in sessions {
            if let Err(err) = session.conversation.submit(Op::Shutdown).await {
                warn!(
                    %session.conversation_id,
                    ?err,
                    "failed to shutdown session during cleanup"
                );
            }
            let _ = self
                .conversation_manager
                .remove_conversation(&session.conversation_id)
                .await;
        }
    }

    async fn spawn_and_register_role(
        &self,
        run_id: &str,
        run_path: &Path,
        role_config: &RoleConfig,
        store: &mut RunStore,
        cleanup: &mut Vec<SessionCleanup>,
    ) -> Result<RoleSession> {
        let session = session::spawn_role(
            Arc::clone(&self.hub),
            &self.conversation_manager,
            run_id,
            run_path,
            role_config.clone(),
            prompts::ensure_instructions,
        )
        .await?;
        cleanup.push(SessionCleanup::new(&session));
        store.update_rollout_path(&session.role, session.rollout_path.clone())?;
        if let Some(path) = role_config.config_path.clone() {
            store.set_role_config_path(&session.role, path)?;
        }
        Ok(session)
    }

    // resumable runs are disabled; resume_and_register_role removed
}

impl InftyOrchestrator {
    /// Test-only helper to run a single verification round against all verifiers,
    /// applying the replacement policy (replace passes, keep failures).
    pub async fn verify_round_for_test(
        &self,
        sessions: &mut RunSessions,
        claim_path: &str,
        options: &RunExecutionOptions,
    ) -> Result<AggregatedVerifierVerdict> {
        let mut pool = VerifierPool::from_sessions(
            Arc::clone(&self.hub),
            sessions,
            options.verifier_timeout,
            self.progress.clone(),
        );
        let req = VerificationRequestPayload::new(claim_path, None, None);
        let round = pool.collect_round(&req).await?;
        pool.rotate_passing(sessions, &self.conversation_manager, &round.passing_roles)
            .await?;
        Ok(round.summary)
    }
}

fn collect_role_metadata(
    solver: &RoleConfig,
    director: &RoleConfig,
    verifiers: &[RoleConfig],
) -> Vec<RoleMetadata> {
    solver_and_director_metadata(solver, director)
        .into_iter()
        .chain(verifiers.iter().map(|verifier| RoleMetadata {
            role: verifier.role.clone(),
            rollout_path: None,
            config_path: verifier.config_path.clone(),
        }))
        .collect()
}

fn solver_and_director_metadata(solver: &RoleConfig, director: &RoleConfig) -> Vec<RoleMetadata> {
    vec![
        RoleMetadata {
            role: solver.role.clone(),
            rollout_path: None,
            config_path: solver.config_path.clone(),
        },
        RoleMetadata {
            role: director.role.clone(),
            rollout_path: None,
            config_path: director.config_path.clone(),
        },
    ]
}

fn collect_session_cleanup(sessions: &RunSessions) -> Vec<SessionCleanup> {
    let mut cleanup = Vec::with_capacity(2 + sessions.verifiers.len());
    cleanup.push(SessionCleanup::new(&sessions.solver));
    cleanup.push(SessionCleanup::new(&sessions.director));
    cleanup.extend(sessions.verifiers.iter().map(SessionCleanup::new));
    cleanup
}
codex-rs/codex-infty/src/progress.rs (new file)
@@ -0,0 +1,25 @@
use std::path::Path;

use codex_core::protocol::AgentMessageEvent;
use codex_core::protocol::EventMsg;

use crate::signals::AggregatedVerifierVerdict;
use crate::signals::DirectiveResponse;
use crate::signals::VerifierVerdict;

pub trait ProgressReporter: Send + Sync {
    fn objective_posted(&self, _objective: &str) {}
    fn solver_event(&self, _event: &EventMsg) {}
    fn role_event(&self, _role: &str, _event: &EventMsg) {}
    fn solver_agent_message(&self, _message: &AgentMessageEvent) {}
    /// Called when the solver emits a message that failed to parse as a valid
    /// JSON signal according to the expected `solver_signal_schema`.
    fn invalid_solver_signal(&self, _raw_message: &str) {}
    fn direction_request(&self, _prompt: &str) {}
    fn director_response(&self, _directive: &DirectiveResponse) {}
    fn verification_request(&self, _claim_path: &str, _notes: Option<&str>) {}
    fn verifier_verdict(&self, _role: &str, _verdict: &VerifierVerdict) {}
    fn verification_summary(&self, _summary: &AggregatedVerifierVerdict) {}
    fn final_delivery(&self, _deliverable_path: &Path, _summary: Option<&str>) {}
    fn run_interrupted(&self) {}
}
codex-rs/codex-infty/src/prompts/director.md (new file)
@@ -0,0 +1,20 @@
You are the **Director**. Your role is to pilot/manage an agent to resolve a given objective in its totality.

## Guidelines:
- The objective needs to be solved in its original format. If the agent proposes a simplification or a partial resolution, this is not sufficient. You must tell the agent to solve the total objective.
- The agent often just reports some results before moving to the next step. In this case, simply encourage it to move on with a short message such as "Go ahead" or "Keep going". In this case, no rationale is needed.
- If the agent proposes multiple approaches, choose the one most likely to solve the objective.
- If the agent is stuck or thinks it cannot resolve the objective, encourage it and try to find a solution together. Your role is to support the agent in its quest. It is sometimes necessary to cheer it up a little.
- No infinite loops! If you detect that the agent sends the exact same message/question multiple times, you are probably in an infinite loop. Try to break it by re-focusing on the objective and how to approach it.
- You must always be crisp and inflexible. Keep the objective in mind.
- Remember that the agent should do the following. If you feel this is not the case, remind it:
  * Document its work
  * Have a very rigorous and clean approach
  * Focus on the total resolution of the objective.
- Challenge the Solver whenever they drift toward summarising existing work instead of advancing the concrete proof or solution.

Respond **only** with JSON in this exact shape:
```json
{"directive":"<directive or next step>","rationale":"<why this is the right move>"}
```
Keep `directive` actionable and concise. Use `rationale` for supporting detail. Leave `rationale` empty if it adds no value.
codex-rs/codex-infty/src/prompts/mod.rs (new file)
@@ -0,0 +1,80 @@
use codex_core::config::Config;

pub(crate) const DIRECTOR_PROMPT: &str = include_str!("director.md");
pub(crate) const SOLVER_PROMPT: &str = include_str!("solver.md");
pub(crate) const VERIFIER_PROMPT: &str = include_str!("verifier.md");

pub fn ensure_instructions(role: &str, config: &mut Config) {
    if config.base_instructions.is_none()
        && let Some(text) = default_instructions_for_role(role)
    {
        config.base_instructions = Some(text.to_string());
    }
}

fn default_instructions_for_role(role: &str) -> Option<&'static str> {
    let normalized = role.to_ascii_lowercase();
    if normalized == "solver" {
        Some(SOLVER_PROMPT)
    } else if normalized == "director" {
        Some(DIRECTOR_PROMPT)
    } else if normalized.starts_with("verifier") {
        Some(VERIFIER_PROMPT)
    } else {
        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use core_test_support::load_default_config_for_test;
    use tempfile::TempDir;

    #[test]
    fn provides_prompts_for_known_roles() {
        let home = TempDir::new().unwrap();
        let mut config = load_default_config_for_test(&home);
        config.base_instructions = None;
        ensure_instructions("solver", &mut config);
        assert!(
            config
                .base_instructions
                .as_ref()
                .unwrap()
                .contains("You are a brilliant mathematician")
        );

        let home = TempDir::new().unwrap();
        let mut config = load_default_config_for_test(&home);
        config.base_instructions = None;
        ensure_instructions("director", &mut config);
        assert!(
            config
                .base_instructions
                .as_ref()
                .unwrap()
                .contains("You are the **Director**")
        );

        let home = TempDir::new().unwrap();
        let mut config = load_default_config_for_test(&home);
        config.base_instructions = None;
        ensure_instructions("verifier-alpha", &mut config);
        assert!(
            config
                .base_instructions
                .as_ref()
                .unwrap()
                .contains("You are the **Verifier**")
        );
    }

    #[test]
    fn does_not_override_existing_instructions() {
        let home = TempDir::new().unwrap();
        let mut config = load_default_config_for_test(&home);
        config.base_instructions = Some("custom".to_string());
        ensure_instructions("solver", &mut config);
        assert_eq!(config.base_instructions.as_deref(), Some("custom"));
    }
}
codex-rs/codex-infty/src/prompts/solver.md (new file)
@@ -0,0 +1,40 @@
You are a brilliant mathematician tasked with producing **new** reasoning, a proof, a construction, or a counterexample that resolves the stated objective. Your goal is to make actual progress in science while being rigorous and innovative.

You MUST solve the provided objective in its totality. If no known solution exists, it is your job to find a new one or to propose an intelligent approach.
A result stating that this is not possible is not acceptable. If the solution does not exist, make it happen.

## Responsibilities
- Understand the objective and break it into a living execution plan.
- Produce artifacts under `artifacts/`, durable notes under `memory/`, and supporting indexes under `index/`. Prefer `apply_patch` for text edits and use `shell` for other filesystem work.
- When you finish a task or take a dependency on external evidence, write JSON notes in `memory/claims/` that link to the supporting artifacts.
- Run verification steps (tests, linters, proofs) under the sandbox before claiming completion.
- Every deliverable must include the actual solution or proof (not just a literature review) and enough detail for the Verifier to reproduce or scrutinise it.
- Your goal is to find new solutions to problems for which humans do not yet have solutions. So do not focus on searching the internet or the literature; try building your own proofs.
- You are very rigorous in your approach.
- You do not fear new challenges. If a problem seems to be impossible to solve, try!

Available Codex tools mirror standard Codex sessions (e.g. `shell`, `apply_patch`). Assume all filesystem paths are relative to the current run store directory unless stated otherwise.

## Communication contract
The orchestrator routes your structured messages to the Director. Respond with **JSON only**: no leading prose or trailing commentary. Wrap JSON in a fenced block only if the agent policy forces it.

- Every reply must populate the full schema, even when a field does not apply. Set unused string fields to `null`.
- Direction request (send to Director):
```json
{"type":"direction_request","prompt":"<concise question or decision>","claim_path":null,"notes":null,"deliverable_path":null,"summary":null}
```
- Final delivery (after receiving the finalization instruction):
```json
{"type":"final_delivery","prompt":null,"claim_path":null,"notes":null,"deliverable_path":"deliverable/summary.txt","summary":"<answer plus supporting context>"}
```

## Operating rhythm
- You MUST always address the comments received from the verifiers.
- Create `deliverable/summary.txt` before every final delivery. Capture the final answer, how you reached it, and any follow-up instructions. Do not forget it.
- When uncertainty remains, prioritise experiments or reasoning steps that move you closer to a finished proof rather than cataloguing known results.
- Do not try to version your work or use git! EVER!
- If you receive the same answer multiple times, you are probably in an infinite loop. Try a new approach or something else.
- Keep the run resilient to restarts: document intent, intermediate results, and follow-up tasks in `memory/`.
- Prefer concrete evidence. Link every claim to artifacts or durable notes so the verifier can reproduce your reasoning.
- On failure feedback from a verifier, address its feedback and update or fix your work.
- Only a final solution to the objective is an acceptable result to be sent to the verifier. If you do not find any solution, try to create a new one on your own.
codex-rs/codex-infty/src/prompts/verifier.md (new file)
@@ -0,0 +1,21 @@
You are the **Verifier**. As a brilliant mathematician, your role is to verify a provided response according to a given objective.

## Guidelines
- You must always be perfectly rigorous when verifying a solution.
- The solution MUST solve the objective in its totality. A partial resolution or a summary of why this is not possible is NOT ACCEPTABLE.
- Evaluate correctness and completeness.
- The solution might try to convince you that a partial resolution is good enough or that a total resolution is not possible. This is NOT ACCEPTABLE and should automatically trigger a `fail`.

## How to answer
When you give the result of your verification:
- Be explicit in your conclusion (does the artifact contain everything? Is it 100% correct?)
- If you are not sure, prefer a `fail`.
- If it is a `fail`, give a crisp analysis of what is wrong or what is missing.

Respond **only** with JSON in this form:
```json
{"verdict":"pass","reasons":[],"suggestions":[]}
```
Use `"fail"` when the claim is not ready. Populate `reasons` with concrete blocking issues. Provide actionable `suggestions` for remediation. Omit entries when not needed.

Do not include extra commentary outside the JSON payload.
98 codex-rs/codex-infty/src/roles/director.rs (Normal file)
@@ -0,0 +1,98 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use codex_core::cross_session::AssistantMessage;
use codex_core::cross_session::CrossSessionHub;
use serde::Serialize;
use serde_json::Value;

use crate::progress::ProgressReporter;
use crate::roles::Role;
use crate::roles::parse_json_struct;
use crate::session;
use crate::signals::DirectiveResponse;

#[derive(Serialize)]
pub struct DirectionRequestPayload<'a> {
    #[serde(rename = "type")]
    kind: &'static str,
    pub prompt: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub objective: Option<&'a str>,
}

impl<'a> DirectionRequestPayload<'a> {
    pub fn new(prompt: &'a str, objective: Option<&'a str>) -> Self {
        Self {
            kind: "direction_request",
            prompt,
            objective,
        }
    }
}

pub struct DirectorRole {
    hub: Arc<CrossSessionHub>,
    run_id: String,
    role: String,
    timeout: Duration,
    progress: Option<Arc<dyn ProgressReporter>>,
}

impl DirectorRole {
    pub fn new(
        hub: Arc<CrossSessionHub>,
        run_id: impl Into<String>,
        role: impl Into<String>,
        timeout: Duration,
        progress: Option<Arc<dyn ProgressReporter>>,
    ) -> Self {
        Self {
            hub,
            run_id: run_id.into(),
            role: role.into(),
            timeout,
            progress,
        }
    }

    pub fn response_schema() -> Value {
        serde_json::json!({
            "type": "object",
            "required": ["directive", "rationale"],
            "properties": {
                "directive": { "type": "string" },
                "rationale": { "type": ["string", "null"] }
            },
            "additionalProperties": false
        })
    }
}

impl Role<DirectionRequestPayload<'_>, DirectiveResponse> for DirectorRole {
    fn call<'a>(
        &'a self,
        req: &'a DirectionRequestPayload<'a>,
    ) -> futures::future::BoxFuture<'a, Result<DirectiveResponse>> {
        Box::pin(async move {
            let request_text = serde_json::to_string_pretty(req)?;
            let handle = session::post_turn(
                self.hub.as_ref(),
                &self.run_id,
                &self.role,
                request_text,
                Some(Self::response_schema()),
            )
            .await?;
            let progress = self
                .progress
                .as_deref()
                .map(|reporter| (reporter, self.role.as_str()));
            let response: AssistantMessage =
                session::await_first_idle(self.hub.as_ref(), &handle, self.timeout, progress)
                    .await?;
            parse_json_struct(&response.message.message)
        })
    }
}
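For reference, `DirectionRequestPayload` serializes with `kind` renamed to `type` and the objective omitted when absent. A minimal sketch, assuming the struct above is in scope (values are illustrative):

```rust
// Illustrative only: the JSON the director receives for a direction request.
fn demo() -> anyhow::Result<()> {
    let payload = DirectionRequestPayload::new("Need directive", Some("Implement feature"));
    assert_eq!(
        serde_json::to_string(&payload)?,
        r#"{"type":"direction_request","prompt":"Need directive","objective":"Implement feature"}"#
    );
    Ok(())
}
```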
49 codex-rs/codex-infty/src/roles/mod.rs (Normal file)
@@ -0,0 +1,49 @@
use anyhow::Result;
use futures::future::BoxFuture;

pub mod director;
pub mod solver;
pub mod verifier;
pub mod verifier_pool;

pub trait Role<Req, Resp> {
    fn call<'a>(&'a self, req: &'a Req) -> BoxFuture<'a, Result<Resp>>;
}

// Shared helpers used by role implementations
use anyhow::Context as _;
use anyhow::anyhow;
use std::any::type_name;

pub(crate) fn strip_json_code_fence(text: &str) -> Option<&str> {
    let trimmed = text.trim();
    if let Some(rest) = trimmed.strip_prefix("```json") {
        return rest.strip_suffix("```").map(str::trim);
    }
    if let Some(rest) = trimmed.strip_prefix("```JSON") {
        return rest.strip_suffix("```").map(str::trim);
    }
    if let Some(rest) = trimmed.strip_prefix("```") {
        return rest.strip_suffix("```").map(str::trim);
    }
    None
}

pub(crate) fn parse_json_struct<T>(message: &str) -> Result<T>
where
    T: serde::de::DeserializeOwned,
{
    let trimmed = message.trim();
    if trimmed.is_empty() {
        return Err(anyhow!("message was empty"));
    }

    serde_json::from_str(trimmed)
        .or_else(|err| {
            strip_json_code_fence(trimmed)
                .map(|inner| serde_json::from_str(inner))
                .unwrap_or_else(|| Err(err))
        })
        .map_err(|err| anyhow!(err))
        .with_context(|| format!("failed to parse message as {}", type_name::<T>()))
}
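A minimal sketch of the fence fallback: a code-fenced payload fails the direct parse, `strip_json_code_fence` removes the fence, and the retry deserializes. It assumes the helpers above are in scope; `Verdict` is a hypothetical target type:

```rust
// Minimal sketch (helpers above in scope; Verdict is hypothetical).
#[derive(Debug, serde::Deserialize)]
struct Verdict {
    verdict: String,
}

fn demo() -> anyhow::Result<()> {
    // A model reply that wraps its JSON in a Markdown code fence.
    let raw = "```json\n{\"verdict\":\"pass\"}\n```";
    // Direct serde_json::from_str fails, the fence is stripped, and the retry succeeds.
    let parsed: Verdict = parse_json_struct(raw)?;
    assert_eq!(parsed.verdict, "pass");
    Ok(())
}
```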
202 codex-rs/codex-infty/src/roles/solver.rs (Normal file)
@@ -0,0 +1,202 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use codex_core::cross_session::AssistantMessage;
use codex_core::cross_session::CrossSessionHub;
use codex_core::cross_session::SessionEventStream;
use codex_protocol::ConversationId;
use serde::de::Error as _;
use serde_json::Value;

use crate::progress::ProgressReporter;
use crate::roles::Role;
use crate::session;
use crate::signals::AggregatedVerifierVerdict;
use crate::signals::DirectiveResponse;

pub struct SolverRole {
    hub: Arc<CrossSessionHub>,
    run_id: String,
    role: String,
    conversation_id: ConversationId,
    progress: Option<Arc<dyn ProgressReporter>>,
}

impl SolverRole {
    pub fn new(
        hub: Arc<CrossSessionHub>,
        run_id: impl Into<String>,
        role: impl Into<String>,
        conversation_id: ConversationId,
        progress: Option<Arc<dyn ProgressReporter>>,
    ) -> Self {
        Self {
            hub,
            run_id: run_id.into(),
            role: role.into(),
            conversation_id,
            progress,
        }
    }

    pub fn solver_signal_schema() -> Value {
        // Only allow asking the director or sending the final result.
        serde_json::json!({
            "type": "object",
            "properties": {
                "type": { "type": "string", "enum": ["direction_request", "final_delivery"] },
                "prompt": { "type": ["string", "null"] },
                "deliverable_path": { "type": ["string", "null"] },
                "summary": { "type": ["string", "null"] }
            },
            "required": ["type", "prompt", "deliverable_path", "summary"],
            "additionalProperties": false
        })
    }

    pub fn final_delivery_schema() -> Value {
        serde_json::json!({
            "type": "object",
            "required": ["type", "deliverable_path", "summary"],
            "properties": {
                "type": { "const": "final_delivery" },
                "deliverable_path": { "type": "string" },
                "summary": { "type": ["string", "null"] }
            },
            "additionalProperties": false
        })
    }

    pub async fn post(
        &self,
        text: impl Into<String>,
        final_output_json_schema: Option<Value>,
    ) -> Result<()> {
        let _ = session::post_turn(
            self.hub.as_ref(),
            &self.run_id,
            &self.role,
            text,
            final_output_json_schema,
        )
        .await?;
        Ok(())
    }

    pub fn stream_events(
        &self,
    ) -> Result<SessionEventStream, codex_core::cross_session::CrossSessionError> {
        self.hub.stream_events(self.conversation_id)
    }

    pub async fn request_finalization_signal(&self) -> Result<()> {
        let handle = session::post_turn(
            self.hub.as_ref(),
            &self.run_id,
            &self.role,
            crate::types::FINALIZATION_PROMPT,
            Some(Self::final_delivery_schema()),
        )
        .await?;
        // Allow more time for the solver to start emitting the
        // finalization signal before timing out as "idle".
        let _ =
            session::await_first_idle(self.hub.as_ref(), &handle, Duration::from_secs(120), None)
                .await?;
        Ok(())
    }
}

pub struct SolverPost {
    pub text: String,
    pub final_output_json_schema: Option<Value>,
    pub timeout: Duration,
}

pub enum SolverRequest {
    Directive(DirectiveResponse),
    VerificationSummary(AggregatedVerifierVerdict),
}

impl From<DirectiveResponse> for SolverRequest {
    fn from(d: DirectiveResponse) -> Self {
        SolverRequest::Directive(d)
    }
}

impl From<&AggregatedVerifierVerdict> for SolverRequest {
    fn from(v: &AggregatedVerifierVerdict) -> Self {
        SolverRequest::VerificationSummary(v.clone())
    }
}

impl SolverRequest {
    fn to_text(&self) -> Result<String> {
        match self {
            SolverRequest::Directive(d) => Ok(serde_json::to_string_pretty(d)?),
            SolverRequest::VerificationSummary(s) => Ok(serde_json::to_string_pretty(s)?),
        }
    }
}

impl Role<SolverPost, AssistantMessage> for SolverRole {
    fn call<'a>(
        &'a self,
        req: &'a SolverPost,
    ) -> futures::future::BoxFuture<'a, Result<AssistantMessage>> {
        Box::pin(async move {
            let handle = session::post_turn(
                self.hub.as_ref(),
                &self.run_id,
                &self.role,
                req.text.clone(),
                req.final_output_json_schema.clone(),
            )
            .await?;
            let progress = self
                .progress
                .as_deref()
                .map(|reporter| (reporter, self.role.as_str()));
            session::await_first_idle(self.hub.as_ref(), &handle, req.timeout, progress).await
        })
    }
}

impl Role<SolverRequest, ()> for SolverRole {
    fn call<'a>(&'a self, req: &'a SolverRequest) -> futures::future::BoxFuture<'a, Result<()>> {
        Box::pin(async move {
            let text = req.to_text()?;
            self.post(text, Some(Self::solver_signal_schema())).await
        })
    }
}

#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SolverSignal {
    DirectionRequest {
        #[serde(default)]
        prompt: Option<String>,
    },
    FinalDelivery {
        #[serde(default)]
        deliverable_path: Option<String>,
        #[serde(default)]
        summary: Option<String>,
    },
}

pub fn parse_solver_signal(message: &str) -> Option<SolverSignal> {
    let trimmed = message.trim();
    if trimmed.is_empty() {
        return None;
    }
    serde_json::from_str(trimmed)
        .or_else(|_| {
            crate::roles::strip_json_code_fence(trimmed)
                .map(|inner| serde_json::from_str(inner.trim()))
                .unwrap_or_else(|| Err(serde_json::Error::custom("invalid payload")))
        })
        .ok()
}
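A short sketch of how `parse_solver_signal` resolves the `type`-tagged enum, assuming `SolverSignal` and the function above are in scope (values are illustrative):

```rust
// Minimal sketch: a final_delivery payload resolves to SolverSignal::FinalDelivery.
fn demo() {
    let raw = r#"{"type":"final_delivery","deliverable_path":"deliverable/summary.txt","summary":"done"}"#;
    match parse_solver_signal(raw) {
        Some(SolverSignal::FinalDelivery { deliverable_path, summary }) => {
            assert_eq!(deliverable_path.as_deref(), Some("deliverable/summary.txt"));
            assert_eq!(summary.as_deref(), Some("done"));
        }
        other => panic!("unexpected signal: {other:?}"),
    }
}
```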
132 codex-rs/codex-infty/src/roles/verifier.rs (Normal file)
@@ -0,0 +1,132 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use codex_core::cross_session::AssistantMessage;
use codex_core::cross_session::CrossSessionHub;
use serde::Serialize;
use serde_json::Value;

use crate::progress::ProgressReporter;
use crate::roles::Role;
use crate::roles::parse_json_struct;
use crate::session;
use crate::signals::AggregatedVerifierVerdict;
use crate::signals::VerifierDecision;
use crate::signals::VerifierReport;
use crate::signals::VerifierVerdict;

#[derive(Serialize)]
pub struct VerificationRequestPayload<'a> {
    #[serde(rename = "type")]
    kind: &'static str,
    pub claim_path: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub notes: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub objective: Option<&'a str>,
}

impl<'a> VerificationRequestPayload<'a> {
    pub fn new(claim_path: &'a str, notes: Option<&'a str>, objective: Option<&'a str>) -> Self {
        Self {
            kind: "verification_request",
            claim_path,
            notes,
            objective,
        }
    }
}

pub struct VerifierRole {
    hub: Arc<CrossSessionHub>,
    run_id: String,
    role: String,
    timeout: Duration,
    progress: Option<Arc<dyn ProgressReporter>>,
}

impl VerifierRole {
    pub fn new(
        hub: Arc<CrossSessionHub>,
        run_id: impl Into<String>,
        role: impl Into<String>,
        timeout: Duration,
        progress: Option<Arc<dyn ProgressReporter>>,
    ) -> Self {
        Self {
            hub,
            run_id: run_id.into(),
            role: role.into(),
            timeout,
            progress,
        }
    }

    pub fn role(&self) -> &str {
        &self.role
    }

    pub fn response_schema() -> Value {
        serde_json::json!({
            "type": "object",
            "required": ["verdict", "reasons", "suggestions"],
            "properties": {
                "verdict": { "type": "string", "enum": ["pass", "fail"] },
                "reasons": { "type": "array", "items": { "type": "string" } },
                "suggestions": { "type": "array", "items": { "type": "string" } }
            },
            "additionalProperties": false
        })
    }
}

impl Role<VerificationRequestPayload<'_>, VerifierVerdict> for VerifierRole {
    fn call<'a>(
        &'a self,
        req: &'a VerificationRequestPayload<'a>,
    ) -> futures::future::BoxFuture<'a, Result<VerifierVerdict>> {
        Box::pin(async move {
            let request_text = serde_json::to_string_pretty(req)?;
            let handle = session::post_turn(
                self.hub.as_ref(),
                &self.run_id,
                &self.role,
                request_text,
                Some(Self::response_schema()),
            )
            .await?;
            let progress = self
                .progress
                .as_deref()
                .map(|reporter| (reporter, self.role.as_str()));
            let response: AssistantMessage =
                session::await_first_idle(self.hub.as_ref(), &handle, self.timeout, progress)
                    .await?;
            parse_json_struct(&response.message.message)
        })
    }
}

pub fn aggregate_verdicts(items: Vec<(String, VerifierVerdict)>) -> AggregatedVerifierVerdict {
    let mut overall = VerifierDecision::Pass;
    let mut verdicts = Vec::with_capacity(items.len());

    for (role, verdict) in items {
        if !verdict.verdict.is_pass() {
            overall = VerifierDecision::Fail;
        }
        verdicts.push(VerifierReport {
            role,
            verdict: verdict.verdict,
            reasons: verdict.reasons,
            suggestions: verdict.suggestions,
        });
    }

    AggregatedVerifierVerdict {
        kind: "verification_feedback",
        overall,
        verdicts,
    }
}
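`aggregate_verdicts` is fail-dominant: a single failing verifier flips the overall decision. A minimal sketch, assuming the signals types above are in scope (values illustrative):

```rust
// Minimal sketch: one fail among the verdicts makes the aggregate fail.
fn demo() {
    let pass = VerifierVerdict {
        verdict: VerifierDecision::Pass,
        reasons: vec![],
        suggestions: vec![],
    };
    let fail = VerifierVerdict {
        verdict: VerifierDecision::Fail,
        reasons: vec!["summary missing".to_string()],
        suggestions: vec![],
    };
    let agg = aggregate_verdicts(vec![("verifier-1".into(), pass), ("verifier-2".into(), fail)]);
    assert_eq!(agg.overall, VerifierDecision::Fail);
    assert_eq!(agg.verdicts.len(), 2);
}
```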
153 codex-rs/codex-infty/src/roles/verifier_pool.rs (Normal file)
@@ -0,0 +1,153 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context as _;
use anyhow::Result;
use codex_core::ConversationManager;
use codex_core::cross_session::CrossSessionHub;
use codex_core::protocol::Op;

use crate::progress::ProgressReporter;
use crate::prompts;
use crate::roles::Role;
use crate::roles::verifier::VerificationRequestPayload;
use crate::roles::verifier::VerifierRole;
use crate::roles::verifier::aggregate_verdicts;
use crate::session;
use crate::signals::AggregatedVerifierVerdict;
use crate::signals::VerifierVerdict;
use crate::types::RoleConfig;
use crate::types::RunSessions;

pub struct VerificationRound {
    pub summary: AggregatedVerifierVerdict,
    pub passing_roles: Vec<String>,
}

pub struct VerifierPool {
    hub: Arc<CrossSessionHub>,
    run_id: String,
    timeout: Duration,
    progress: Option<Arc<dyn ProgressReporter>>,
    roles: Vec<VerifierRole>,
}

impl VerifierPool {
    pub fn from_sessions(
        hub: Arc<CrossSessionHub>,
        sessions: &RunSessions,
        timeout: Duration,
        progress: Option<Arc<dyn ProgressReporter>>,
    ) -> Self {
        let roles = sessions
            .verifiers
            .iter()
            .map(|v| {
                VerifierRole::new(
                    Arc::clone(&hub),
                    sessions.run_id.clone(),
                    v.role.clone(),
                    timeout,
                    progress.clone(),
                )
            })
            .collect();
        Self {
            hub,
            run_id: sessions.run_id.clone(),
            timeout,
            progress,
            roles,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.roles.is_empty()
    }

    pub async fn collect_round(
        &self,
        request: &VerificationRequestPayload<'_>,
    ) -> Result<VerificationRound> {
        let futures = self
            .roles
            .iter()
            .map(|role| async {
                let name = role.role().to_string();
                let verdict = role.call(request).await;
                (name, verdict)
            })
            .collect::<Vec<_>>();
        let joined = futures::future::join_all(futures).await;

        let mut results: Vec<(String, VerifierVerdict)> = Vec::with_capacity(joined.len());
        let mut passing_roles: Vec<String> = Vec::new();
        for (name, verdict_res) in joined.into_iter() {
            let verdict = verdict_res
                .with_context(|| format!("verifier {} returned invalid verdict JSON", name))?;
            if let Some(progress) = self.progress.as_ref() {
                progress.verifier_verdict(&name, &verdict);
            }
            if verdict.verdict.is_pass() {
                passing_roles.push(name.clone());
            }
            results.push((name, verdict));
        }
        let summary = aggregate_verdicts(results);
        Ok(VerificationRound {
            summary,
            passing_roles,
        })
    }

    pub fn replace_role(&mut self, role_name: &str) {
        if let Some(idx) = self.roles.iter().position(|v| v.role() == role_name) {
            self.roles[idx] = VerifierRole::new(
                Arc::clone(&self.hub),
                self.run_id.clone(),
                role_name.to_string(),
                self.timeout,
                self.progress.clone(),
            );
        }
    }

    pub async fn rotate_passing(
        &mut self,
        sessions: &mut RunSessions,
        manager: &ConversationManager,
        passing_roles: &[String],
    ) -> Result<()> {
        for role in passing_roles {
            // find existing index
            let Some(idx) = sessions.verifiers.iter().position(|s| &s.role == role) else {
                continue;
            };
            let old = &sessions.verifiers[idx];
            // best-effort shutdown and unregister
            let _ = old.conversation.submit(Op::Shutdown).await;
            let _ = manager.remove_conversation(&old.conversation_id).await;

            // Reuse the existing verifier's config so overrides (e.g., base_url in tests)
            // are preserved when respawning a passing verifier.
            let config = old.config.clone();
            let role_config = RoleConfig::new(role.to_string(), config);
            let run_path = sessions.store.path();
            let session = session::spawn_role(
                Arc::clone(&self.hub),
                manager,
                &self.run_id,
                run_path,
                role_config,
                prompts::ensure_instructions,
            )
            .await?;
            sessions
                .store
                .update_rollout_path(&session.role, session.rollout_path.clone())?;
            sessions.verifiers[idx] = session;
            self.replace_role(role);
        }
        Ok(())
    }
}
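Design note: `rotate_passing` respawns only the verifiers that passed, presumably so each subsequent claim is re-examined by a fresh context rather than one anchored to an earlier approval, while failing verifiers keep their sessions and can check whether their feedback was addressed.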
211 codex-rs/codex-infty/src/run_store.rs (Normal file)
@@ -0,0 +1,211 @@
use std::fs;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use tempfile::NamedTempFile;

const ARTIFACTS_DIR: &str = "artifacts";
const MEMORY_DIR: &str = "memory";
const INDEX_DIR: &str = "index";
const DELIVERABLE_DIR: &str = "deliverable";
const METADATA_FILE: &str = "run.json";

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoleMetadata {
    pub role: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rollout_path: Option<PathBuf>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub config_path: Option<PathBuf>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunMetadata {
    pub run_id: String,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub roles: Vec<RoleMetadata>,
}

#[derive(Debug, Clone)]
pub struct RunStore {
    path: PathBuf,
    metadata: RunMetadata,
}

impl RunStore {
    pub fn initialize(
        run_path: impl AsRef<Path>,
        run_id: &str,
        roles: &[RoleMetadata],
    ) -> Result<Self> {
        let run_path = run_path.as_ref().to_path_buf();
        fs::create_dir_all(&run_path)
            .with_context(|| format!("failed to create run directory {}", run_path.display()))?;

        for child in [ARTIFACTS_DIR, MEMORY_DIR, INDEX_DIR, DELIVERABLE_DIR] {
            fs::create_dir_all(run_path.join(child))
                .with_context(|| format!("failed to create subdirectory {child}"))?;
        }

        let metadata_path = run_path.join(METADATA_FILE);
        if metadata_path.exists() {
            return Err(anyhow!(
                "run metadata already exists at {}",
                metadata_path.display()
            ));
        }

        let now = Utc::now();
        let metadata = RunMetadata {
            run_id: run_id.to_string(),
            created_at: now,
            updated_at: now,
            roles: roles.to_vec(),
        };
        write_metadata(&metadata_path, &metadata)?;

        Ok(Self {
            path: run_path,
            metadata,
        })
    }

    pub fn load(run_path: impl AsRef<Path>) -> Result<Self> {
        let run_path = run_path.as_ref().to_path_buf();
        let metadata_path = run_path.join(METADATA_FILE);
        let metadata: RunMetadata = serde_json::from_slice(
            &fs::read(&metadata_path)
                .with_context(|| format!("failed to read {}", metadata_path.display()))?,
        )
        .with_context(|| format!("failed to parse {}", metadata_path.display()))?;

        Ok(Self {
            path: run_path,
            metadata,
        })
    }

    pub fn path(&self) -> &Path {
        &self.path
    }

    pub fn metadata(&self) -> &RunMetadata {
        &self.metadata
    }

    pub fn role_metadata(&self, role: &str) -> Option<&RoleMetadata> {
        self.metadata.roles.iter().find(|meta| meta.role == role)
    }

    pub fn update_rollout_path(&mut self, role: &str, rollout_path: PathBuf) -> Result<()> {
        if let Some(meta) = self
            .metadata
            .roles
            .iter_mut()
            .find(|meta| meta.role == role)
        {
            meta.rollout_path = Some(rollout_path);
            self.commit_metadata()
        } else {
            Err(anyhow!("role {role} not found in run store"))
        }
    }

    pub fn set_role_config_path(&mut self, role: &str, path: PathBuf) -> Result<()> {
        if let Some(meta) = self
            .metadata
            .roles
            .iter_mut()
            .find(|meta| meta.role == role)
        {
            meta.config_path = Some(path);
            self.commit_metadata()
        } else {
            Err(anyhow!("role {role} not found in run store"))
        }
    }

    pub fn touch(&mut self) -> Result<()> {
        self.metadata.updated_at = Utc::now();
        self.commit_metadata()
    }

    fn commit_metadata(&mut self) -> Result<()> {
        self.metadata.updated_at = Utc::now();
        let metadata_path = self.path.join(METADATA_FILE);
        write_metadata(&metadata_path, &self.metadata)
    }
}

fn write_metadata(path: &Path, metadata: &RunMetadata) -> Result<()> {
    let parent = path
        .parent()
        .ok_or_else(|| anyhow!("metadata path must have parent"))?;
    let mut temp = NamedTempFile::new_in(parent)
        .with_context(|| format!("failed to create temp file in {}", parent.display()))?;
    serde_json::to_writer_pretty(&mut temp, metadata)?;
    temp.flush()?;
    temp.persist(path)
        .with_context(|| format!("failed to persist metadata to {}", path.display()))?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn initialize_creates_directories_and_metadata() {
        let temp = TempDir::new().unwrap();
        let run_path = temp.path().join("run_1");
        let roles = vec![
            RoleMetadata {
                role: "solver".into(),
                rollout_path: None,
                config_path: None,
            },
            RoleMetadata {
                role: "director".into(),
                rollout_path: None,
                config_path: None,
            },
        ];

        let store = RunStore::initialize(&run_path, "run_1", &roles).unwrap();
        assert!(store.path().join(ARTIFACTS_DIR).is_dir());
        assert!(store.path().join(MEMORY_DIR).is_dir());
        assert!(store.path().join(INDEX_DIR).is_dir());
        assert!(store.path().join(DELIVERABLE_DIR).is_dir());
        assert_eq!(store.metadata().roles.len(), 2);
    }

    #[test]
    fn update_rollout_persists_metadata() {
        let temp = TempDir::new().unwrap();
        let run_path = temp.path().join("run_2");
        let roles = vec![RoleMetadata {
            role: "solver".into(),
            rollout_path: None,
            config_path: None,
        }];
        let mut store = RunStore::initialize(&run_path, "run_2", &roles).unwrap();
        let rollout = PathBuf::from("/tmp/rollout.jsonl");
        store
            .update_rollout_path("solver", rollout.clone())
            .unwrap();

        let loaded = RunStore::load(&run_path).unwrap();
        let solver = loaded.role_metadata("solver").unwrap();
        assert_eq!(solver.rollout_path.as_ref().unwrap(), &rollout);
    }
}
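Design note: `write_metadata` writes to a `NamedTempFile` created in the same directory and then renames it over `run.json` via `persist`, so the rename stays on a single filesystem and a crash mid-write cannot leave a truncated metadata file behind.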
112 codex-rs/codex-infty/src/session.rs (Normal file)
@@ -0,0 +1,112 @@
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use codex_core::ConversationManager;
use codex_core::CrossSessionSpawnParams;
use codex_core::config::Config;
use codex_core::cross_session::AssistantMessage;
use codex_core::cross_session::CrossSessionError;
use codex_core::cross_session::CrossSessionHub;
use codex_core::cross_session::PostUserTurnRequest;
use codex_core::cross_session::RoleOrId;
use codex_core::cross_session::TurnHandle;
use serde_json::Value;
use tokio::time::Instant;
use tokio_stream::StreamExt as _;

use crate::progress::ProgressReporter;
use crate::types::RoleConfig;
use crate::types::RoleSession;

pub async fn spawn_role(
    hub: Arc<CrossSessionHub>,
    manager: &ConversationManager,
    run_id: &str,
    run_path: &Path,
    role_config: RoleConfig,
    ensure_instructions: impl FnOnce(&str, &mut Config),
) -> Result<RoleSession> {
    let RoleConfig {
        role, mut config, ..
    } = role_config;
    config.cwd = run_path.to_path_buf();
    ensure_instructions(&role, &mut config);
    let cfg_for_session = config.clone();
    let session = manager
        .new_conversation_with_cross_session(
            cfg_for_session,
            CrossSessionSpawnParams {
                hub: Arc::clone(&hub),
                run_id: Some(run_id.to_string()),
                role: Some(role.clone()),
            },
        )
        .await?;
    // Note: include the final config used to spawn the session
    Ok(RoleSession::from_new(role, session, config))
}

// resumable runs are disabled for now; resume_role removed

pub async fn post_turn(
    hub: &CrossSessionHub,
    run_id: &str,
    role: &str,
    text: impl Into<String>,
    final_output_json_schema: Option<Value>,
) -> Result<TurnHandle, CrossSessionError> {
    hub.post_user_turn(PostUserTurnRequest {
        target: RoleOrId::RunRole {
            run_id: run_id.to_string(),
            role: role.to_string(),
        },
        text: text.into(),
        final_output_json_schema,
    })
    .await
}

pub async fn await_first_idle(
    hub: &CrossSessionHub,
    handle: &TurnHandle,
    idle_timeout: Duration,
    progress: Option<(&dyn ProgressReporter, &str)>,
) -> Result<AssistantMessage> {
    let mut events = hub.stream_events(handle.conversation_id())?;
    let wait_first = hub.await_first_assistant(handle, idle_timeout);
    tokio::pin!(wait_first);

    let idle = tokio::time::sleep(idle_timeout);
    tokio::pin!(idle);

    let submission_id = handle.submission_id().to_string();

    loop {
        tokio::select! {
            result = &mut wait_first => {
                return result.map_err(|err| anyhow!(err));
            }
            maybe_event = events.next() => {
                let Some(event) = maybe_event else {
                    bail!(CrossSessionError::SessionClosed);
                };
                if event.event.id == submission_id {
                    if let Some((reporter, role)) = progress {
                        reporter.role_event(role, &event.event.msg);
                    }
                    if let codex_core::protocol::EventMsg::Error(err) = &event.event.msg {
                        bail!(anyhow!(err.message.clone()));
                    }
                    idle.as_mut().reset(Instant::now() + idle_timeout);
                }
            }
            _ = &mut idle => {
                bail!(CrossSessionError::AwaitTimeout(idle_timeout));
            }
        }
    }
}
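The `select!` loop in `await_first_idle` implements a sliding idle deadline: every event observed for the submission pushes the timeout out by another `idle_timeout`, so a turn only fails as idle after a genuinely quiet window. A standalone sketch of that pattern (tokio only; names and timings are illustrative):

```rust
// Standalone sketch of the sliding-idle pattern used by await_first_idle:
// each event observed for the turn pushes the deadline out by idle_timeout.
use std::time::Duration;

use tokio::time::Instant;
use tokio::time::sleep;

#[tokio::main]
async fn main() {
    let idle_timeout = Duration::from_millis(100);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<u32>(8);
    tokio::spawn(async move {
        for i in 0..3 {
            // Events arrive faster than the timeout, so the deadline keeps sliding.
            sleep(Duration::from_millis(50)).await;
            let _ = tx.send(i).await;
        }
    });

    let idle = sleep(idle_timeout);
    tokio::pin!(idle);
    loop {
        tokio::select! {
            maybe_event = rx.recv() => {
                match maybe_event {
                    Some(i) => {
                        println!("event {i}: resetting the idle deadline");
                        idle.as_mut().reset(Instant::now() + idle_timeout);
                    }
                    // Sender dropped: no more events will arrive.
                    None => break,
                }
            }
            _ = &mut idle => {
                println!("no events for {idle_timeout:?}; treating the turn as idle");
                break;
            }
        }
    }
}
```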
55 codex-rs/codex-infty/src/signals.rs (Normal file)
@@ -0,0 +1,55 @@
use serde::Deserialize;
use serde::Serialize;

#[derive(Debug, Deserialize, Serialize)]
pub struct DirectiveResponse {
    pub directive: String,
    #[serde(default)]
    pub rationale: Option<String>,
}

#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum VerifierDecision {
    Pass,
    Fail,
}

impl VerifierDecision {
    pub fn is_pass(self) -> bool {
        matches!(self, VerifierDecision::Pass)
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct VerifierVerdict {
    pub verdict: VerifierDecision,
    #[serde(default)]
    pub reasons: Vec<String>,
    #[serde(default)]
    pub suggestions: Vec<String>,
}

#[derive(Debug, Serialize, Clone)]
pub struct VerifierReport {
    pub role: String,
    pub verdict: VerifierDecision,
    #[serde(default)]
    pub reasons: Vec<String>,
    #[serde(default)]
    pub suggestions: Vec<String>,
}

#[derive(Debug, Serialize, Clone)]
pub struct AggregatedVerifierVerdict {
    #[serde(rename = "type")]
    pub kind: &'static str,
    pub overall: VerifierDecision,
    pub verdicts: Vec<VerifierReport>,
}

impl From<&AggregatedVerifierVerdict> for String {
    fn from(value: &AggregatedVerifierVerdict) -> Self {
        serde_json::to_string_pretty(value).unwrap_or_else(|_| "{}".to_string())
    }
}
103 codex-rs/codex-infty/src/types.rs (Normal file)
@@ -0,0 +1,103 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use codex_core::CodexConversation;
use codex_core::NewConversation;
use codex_core::config::Config;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::ConversationId;

pub(crate) const DEFAULT_DIRECTOR_TIMEOUT: Duration = Duration::from_secs(1200);
pub(crate) const DEFAULT_VERIFIER_TIMEOUT: Duration = Duration::from_secs(1800);
pub(crate) const FINALIZATION_PROMPT: &str = "Create deliverable/: include compiled artifacts or scripts, usage docs, and tests. Write deliverable/summary.txt capturing the final answer, evidence, and follow-up steps. Also provide deliverable/README.md with overview, manifest (paths and sizes), verification steps, and limitations. Remove scratch files. Reply with JSON: {\"type\":\"final_delivery\",\"deliverable_path\":\"deliverable/summary.txt\",\"summary\":\"<answer plus supporting context>\"}.";

#[derive(Clone)]
pub struct RoleConfig {
    pub role: String,
    pub config: Config,
    pub config_path: Option<PathBuf>,
}

impl RoleConfig {
    pub fn new(role: impl Into<String>, mut config: Config) -> Self {
        config.sandbox_policy = SandboxPolicy::DangerFullAccess;
        config.approval_policy = AskForApproval::Never;
        Self {
            role: role.into(),
            config,
            config_path: None,
        }
    }

    pub fn with_path(role: impl Into<String>, config: Config, config_path: PathBuf) -> Self {
        Self {
            role: role.into(),
            config,
            config_path: Some(config_path),
        }
    }
}

pub struct RunParams {
    pub run_id: String,
    pub run_root: Option<PathBuf>,
    pub solver: RoleConfig,
    pub director: RoleConfig,
    pub verifiers: Vec<RoleConfig>,
}

#[derive(Clone)]
pub struct RunExecutionOptions {
    pub objective: Option<String>,
    pub director_timeout: Duration,
    pub verifier_timeout: Duration,
}

impl Default for RunExecutionOptions {
    fn default() -> Self {
        Self {
            objective: None,
            director_timeout: DEFAULT_DIRECTOR_TIMEOUT,
            verifier_timeout: DEFAULT_VERIFIER_TIMEOUT,
        }
    }
}

pub struct RunOutcome {
    pub run_id: String,
    pub deliverable_path: PathBuf,
    pub summary: Option<String>,
    pub raw_message: String,
}

pub struct RoleSession {
    pub role: String,
    pub conversation_id: ConversationId,
    pub conversation: Arc<CodexConversation>,
    pub session_configured: codex_core::protocol::SessionConfiguredEvent,
    pub rollout_path: PathBuf,
    pub config: Config,
}

impl RoleSession {
    pub(crate) fn from_new(role: String, session: NewConversation, config: Config) -> Self {
        Self {
            role,
            conversation_id: session.conversation_id,
            conversation: session.conversation,
            session_configured: session.session_configured.clone(),
            rollout_path: session.session_configured.rollout_path.clone(),
            config,
        }
    }
}

pub struct RunSessions {
    pub run_id: String,
    pub solver: RoleSession,
    pub director: RoleSession,
    pub verifiers: Vec<RoleSession>,
    pub store: crate::RunStore,
}
91 codex-rs/codex-infty/src/utils.rs (Normal file)
@@ -0,0 +1,91 @@
use std::path::Path;
use std::path::PathBuf;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;

pub fn trim_to_non_empty(opt: Option<String>) -> Option<String> {
    opt.and_then(|s| {
        let trimmed = s.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    })
}

pub fn required_trimmed(opt: Option<String>, err_msg: &str) -> Result<String> {
    trim_to_non_empty(opt).ok_or_else(|| anyhow!(err_msg.to_string()))
}

pub fn resolve_deliverable_path(base: &Path, candidate: &str) -> Result<PathBuf> {
    let base_abs = base
        .canonicalize()
        .with_context(|| format!("failed to canonicalize run store {}", base.display()))?;

    let candidate_path = Path::new(candidate);
    let joined = if candidate_path.is_absolute() {
        candidate_path.to_path_buf()
    } else {
        base_abs.join(candidate_path)
    };

    let resolved = joined.canonicalize().with_context(|| {
        format!(
            "failed to canonicalize deliverable path {}",
            joined.display()
        )
    })?;

    if !resolved.starts_with(&base_abs) {
        bail!(
            "deliverable path {} escapes run store {}",
            resolved.display(),
            base_abs.display()
        );
    }

    Ok(resolved)
}

pub fn objective_as_str(options: &crate::types::RunExecutionOptions) -> Option<&str> {
    options
        .objective
        .as_deref()
        .map(str::trim)
        .filter(|s| !s.is_empty())
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn resolve_deliverable_within_base() {
        let tmp = TempDir::new().unwrap();
        let base = tmp.path();
        std::fs::create_dir_all(base.join("deliverable")).unwrap();
        std::fs::write(base.join("deliverable").join("a.txt"), "ok").unwrap();
        let resolved = resolve_deliverable_path(base, "deliverable/a.txt").unwrap();
        let base_abs = base.canonicalize().unwrap();
        assert!(resolved.starts_with(&base_abs));
    }

    #[test]
    fn resolve_deliverable_rejects_escape() {
        let tmp = TempDir::new().unwrap();
        let base = tmp.path();
        // Create a real file outside of base so canonicalization succeeds
        let outside = TempDir::new().unwrap();
        let outside_file = outside.path().join("outside.txt");
        std::fs::write(&outside_file, "nope").unwrap();

        let err = resolve_deliverable_path(base, outside_file.to_str().unwrap()).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("escapes run store"));
    }
}
327 codex-rs/codex-infty/tests/orchestrator.rs (Normal file)
@@ -0,0 +1,327 @@
#![cfg(not(target_os = "windows"))]

use std::time::Duration;

use codex_core::CodexAuth;
use codex_core::built_in_model_providers;
use codex_core::config::Config;
use codex_core::cross_session::AssistantMessage;
use codex_core::cross_session::PostUserTurnRequest;
use codex_core::cross_session::RoleOrId;
use codex_core::protocol::Op;
use codex_infty::InftyOrchestrator;
use codex_infty::RoleConfig;
use codex_infty::RunExecutionOptions;
use codex_infty::RunParams;
use core_test_support::load_default_config_for_test;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use tempfile::TempDir;
use wiremock::MockServer;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn orchestrator_routes_between_roles_and_records_store() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;
    let bodies = vec![
        responses::sse(vec![
            responses::ev_response_created("solver-resp-1"),
            responses::ev_assistant_message("solver-msg-1", "Need direction"),
            responses::ev_completed("solver-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("director-resp-1"),
            responses::ev_assistant_message("director-msg-1", "Proceed iteratively"),
            responses::ev_completed("director-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("solver-resp-2"),
            responses::ev_assistant_message("solver-msg-2", "Acknowledged"),
            responses::ev_completed("solver-resp-2"),
        ]),
    ];
    let response_mock = responses::mount_sse_sequence(&server, bodies).await;

    let runs_root = TempDir::new()?;
    let orchestrator =
        InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
    let run_id = "run-orchestrator".to_string();

    let solver_config = build_config(&server).await?;
    let director_config = build_config(&server).await?;

    let sessions = orchestrator
        .spawn_run(RunParams {
            run_id: run_id.clone(),
            run_root: Some(runs_root.path().join("runs").join(&run_id)),
            solver: RoleConfig::new("solver", solver_config.clone()),
            director: RoleConfig::new("director", director_config.clone()),
            verifiers: Vec::new(),
        })
        .await?;

    let solver_message = call_role(
        &orchestrator,
        &sessions.run_id,
        "solver",
        "kick off plan",
        Duration::from_secs(1),
    )
    .await?;
    assert_eq!(solver_message.message.message, "Need direction");

    let director_message = relay_assistant_to_role(
        &orchestrator,
        &sessions.run_id,
        "director",
        &solver_message,
        Duration::from_secs(1),
    )
    .await?;
    assert_eq!(director_message.message.message, "Proceed iteratively");

    let solver_reply = relay_assistant_to_role(
        &orchestrator,
        &sessions.run_id,
        "solver",
        &director_message,
        Duration::from_secs(1),
    )
    .await?;
    assert_eq!(solver_reply.message.message, "Acknowledged");

    assert_eq!(response_mock.requests().len(), 3);
    let first_request = response_mock.requests().first().unwrap().body_json();
    let instructions = first_request["instructions"]
        .as_str()
        .expect("request should set instructions");
    assert!(
        instructions.contains("brilliant mathematician"),
        "missing solver prompt: {instructions}"
    );
    assert!(sessions.store.path().is_dir());
    let solver_meta = sessions.store.role_metadata("solver").unwrap();
    assert!(solver_meta.rollout_path.is_some());

    Ok(())
}

// resumable runs are disabled; resume test removed

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn execute_new_run_drives_to_completion() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;
    let bodies = vec![
        responses::sse(vec![
            responses::ev_response_created("solver-resp-1"),
            responses::ev_assistant_message(
                "solver-msg-1",
                r#"{"type":"direction_request","prompt":"Need directive","claim_path":null,"notes":null,"deliverable_path":null,"summary":null}"#,
            ),
            responses::ev_completed("solver-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("director-resp-1"),
            responses::ev_assistant_message(
                "director-msg-1",
                r#"{"directive":"Proceed","rationale":"Follow the plan"}"#,
            ),
            responses::ev_completed("director-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("solver-resp-2"),
            responses::ev_assistant_message("solver-msg-2", "Acknowledged"),
            responses::ev_assistant_message(
                "solver-msg-4",
                r#"{"type":"final_delivery","prompt":null,"claim_path":null,"notes":null,"deliverable_path":"deliverable","summary":"done"}"#,
            ),
            responses::ev_completed("solver-resp-2"),
        ]),
        // Final verification of the deliverable
        responses::sse(vec![
            responses::ev_response_created("verifier-resp-3"),
            responses::ev_assistant_message(
                "verifier-msg-3",
                r#"{"verdict":"pass","reasons":[],"suggestions":[]}"#,
            ),
            responses::ev_completed("verifier-resp-3"),
        ]),
        // Feedback turn summarizing the verification outcome back to the solver
        responses::sse(vec![
            responses::ev_response_created("solver-resp-5"),
            responses::ev_completed("solver-resp-5"),
        ]),
    ];
    for body in bodies {
        responses::mount_sse_once(&server, body).await;
    }

    let runs_root = TempDir::new()?;
    let orchestrator =
        InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
    let run_id = "run-auto".to_string();
    let run_root = runs_root.path().join("runs").join(&run_id);

    let solver_config = build_config(&server).await?;
    let director_config = build_config(&server).await?;
    let verifier_config = build_config(&server).await?;

    let options = RunExecutionOptions {
        objective: Some("Implement feature".to_string()),
        ..RunExecutionOptions::default()
    };

    let outcome = orchestrator
        .execute_new_run(
            RunParams {
                run_id: run_id.clone(),
                run_root: Some(run_root.clone()),
                solver: RoleConfig::new("solver", solver_config),
                director: RoleConfig::new("director", director_config),
                verifiers: vec![RoleConfig::new("verifier", verifier_config)],
            },
            options,
        )
        .await?;

    assert_eq!(outcome.run_id, run_id);
    assert_eq!(outcome.summary.as_deref(), Some("done"));
    assert!(outcome.raw_message.contains("final_delivery"));
    let canonical_run_root = std::fs::canonicalize(&run_root)?;
    let canonical_deliverable = std::fs::canonicalize(&outcome.deliverable_path)?;
    assert!(canonical_deliverable.starts_with(&canonical_run_root));

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn spawn_run_cleans_up_on_failure() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;
    let bodies = vec![
        responses::sse(vec![
            responses::ev_response_created("solver-resp-1"),
            responses::ev_completed("solver-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("director-resp-1"),
            responses::ev_completed("director-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("dup-resp"),
            responses::ev_completed("dup-resp"),
        ]),
    ];
    for body in bodies {
        responses::mount_sse_once(&server, body).await;
    }

    let runs_root = TempDir::new()?;
    let orchestrator =
        InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
    let run_id = "run-cleanup".to_string();
    let run_path = runs_root.path().join("runs").join(&run_id);

    let solver_config = build_config(&server).await?;
    let director_config = build_config(&server).await?;

    let result = orchestrator
        .spawn_run(RunParams {
            run_id: run_id.clone(),
            run_root: Some(run_path.clone()),
            solver: RoleConfig::new("solver", solver_config.clone()),
            director: RoleConfig::new("director", director_config.clone()),
            verifiers: vec![RoleConfig::new("solver", solver_config.clone())],
        })
        .await;
    assert!(result.is_err());
    assert!(!run_path.exists(), "failed run should remove run directory");

    let bodies = vec![
        responses::sse(vec![
            responses::ev_response_created("solver-resp-2"),
            responses::ev_completed("solver-resp-2"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("director-resp-2"),
            responses::ev_completed("director-resp-2"),
        ]),
    ];
    for body in bodies {
        responses::mount_sse_once(&server, body).await;
    }

    let sessions = orchestrator
        .spawn_run(RunParams {
            run_id: run_id.clone(),
            run_root: Some(run_path.clone()),
            solver: RoleConfig::new("solver", solver_config),
            director: RoleConfig::new("director", director_config),
            verifiers: Vec::new(),
        })
        .await?;

    sessions.solver.conversation.submit(Op::Shutdown).await.ok();
    sessions
        .director
        .conversation
        .submit(Op::Shutdown)
        .await
        .ok();

    Ok(())
}

async fn build_config(server: &MockServer) -> anyhow::Result<Config> {
    let home = TempDir::new()?;
    let cwd = TempDir::new()?;
    let mut config = load_default_config_for_test(&home);
    config.cwd = cwd.path().to_path_buf();
    let mut provider = built_in_model_providers()["openai"].clone();
    provider.base_url = Some(format!("{}/v1", server.uri()));
    config.model_provider = provider;
    Ok(config)
}

async fn call_role(
    orchestrator: &InftyOrchestrator,
    run_id: &str,
    role: &str,
    text: &str,
    timeout: Duration,
) -> anyhow::Result<AssistantMessage> {
    let hub = orchestrator.hub();
    let handle = hub
        .post_user_turn(PostUserTurnRequest {
            target: RoleOrId::RunRole {
                run_id: run_id.to_string(),
                role: role.to_string(),
            },
            text: text.to_string(),
            final_output_json_schema: None,
        })
        .await?;
    let reply = hub.await_first_assistant(&handle, timeout).await?;
    Ok(reply)
}

async fn relay_assistant_to_role(
    orchestrator: &InftyOrchestrator,
    run_id: &str,
    target_role: &str,
    assistant: &AssistantMessage,
    timeout: Duration,
) -> anyhow::Result<AssistantMessage> {
    call_role(
        orchestrator,
        run_id,
        target_role,
        &assistant.message.message,
        timeout,
    )
    .await
}
324
codex-rs/codex-infty/tests/schemas.rs
Normal file
324
codex-rs/codex-infty/tests/schemas.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
#![cfg(not(target_os = "windows"))]
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::config::Config;
|
||||
use codex_infty::InftyOrchestrator;
|
||||
use codex_infty::RoleConfig;
|
||||
use codex_infty::RunExecutionOptions;
|
||||
use codex_infty::RunParams;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::MockServer;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn director_request_includes_output_schema() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
// 1) Solver: emit a direction_request so the orchestrator calls Director.
|
||||
let body_solver = responses::sse(vec![
|
||||
responses::ev_response_created("solver-resp-1"),
|
||||
responses::ev_assistant_message(
|
||||
"solver-msg-1",
|
||||
r#"{"type":"direction_request","prompt":"Need directive","claim_path":null,"notes":null,"deliverable_path":null,"summary":null}"#,
|
||||
),
|
||||
responses::ev_completed("solver-resp-1"),
|
||||
]);
|
||||
let _mock_solver = responses::mount_sse_once(&server, body_solver).await;
|
||||
|
||||
// 2) Director: reply with a directive JSON.
|
||||
let body_director = responses::sse(vec![
|
||||
responses::ev_response_created("director-resp-1"),
|
||||
responses::ev_assistant_message(
|
||||
"director-msg-1",
|
||||
r#"{"directive":"Proceed","rationale":"Follow the plan"}"#,
|
||||
),
|
||||
responses::ev_completed("director-resp-1"),
|
||||
]);
|
||||
let mock_director = responses::mount_sse_once(&server, body_director).await;
|
||||
|
||||
// 3) After relaying directive back to Solver, we do not need to continue the run.
|
||||
// Provide a short empty solver completion body to avoid hanging HTTP calls.
|
||||
let body_solver_after = responses::sse(vec![
|
||||
responses::ev_response_created("solver-resp-2"),
|
||||
responses::ev_completed("solver-resp-2"),
|
||||
]);
|
||||
let _mock_solver_after = responses::mount_sse_once(&server, body_solver_after).await;
|
||||
|
||||
let runs_root = TempDir::new()?;
|
||||
let orchestrator =
|
||||
InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
|
||||
let run_id = "run-director-schema".to_string();
|
||||
|
||||
let solver_config = build_config(&server).await?;
|
||||
let director_config = build_config(&server).await?;
|
||||
|
||||
let params = RunParams {
|
||||
run_id: run_id.clone(),
|
||||
run_root: Some(runs_root.path().join("runs").join(&run_id)),
|
||||
solver: RoleConfig::new("solver", solver_config),
|
||||
director: RoleConfig::new("director", director_config),
|
||||
verifiers: Vec::new(),
|
||||
};
|
||||
|
||||
let options = RunExecutionOptions {
|
||||
objective: Some("Kick off".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Drive the run in the background; we'll assert the request shape then cancel.
|
||||
let fut = tokio::spawn(async move {
|
||||
let _ = orchestrator.execute_new_run(params, options).await;
|
||||
});
|
||||
|
||||
// Wait until the Director request is captured.
|
||||
wait_for_requests(&mock_director, 1, Duration::from_secs(2)).await;
|
||||
let req = mock_director.single_request();
|
||||
let body = req.body_json();
|
||||
|
||||
// Assert that a JSON schema was sent under text.format.
|
||||
let text = &body["text"]; // Optional; present when using schemas
|
||||
assert!(text.is_object(), "missing text controls in request body");
|
||||
let fmt = &text["format"];
|
||||
assert!(fmt.is_object(), "missing text.format in request body");
|
||||
assert_eq!(fmt["type"], "json_schema");
|
||||
let schema = &fmt["schema"];
|
||||
assert!(schema.is_object(), "missing text.format.schema");
|
||||
assert_eq!(schema["type"], "object");
|
||||
// Ensure the directive property exists and is a string.
|
||||
assert_eq!(schema["properties"]["directive"]["type"], "string");
|
||||
// Enforce strictness: required must include all properties.
|
||||
let required = schema["required"]
|
||||
.as_array()
|
||||
.expect("required must be array");
|
||||
let props = schema["properties"]
|
||||
.as_object()
|
||||
.expect("properties must be object");
|
||||
for key in props.keys() {
|
||||
assert!(
|
||||
required.iter().any(|v| v == key),
|
||||
"missing {key} in required"
|
||||
);
|
||||
}
|
||||
// Ensure the objective text appears in the serialized request body
|
||||
let raw = serde_json::to_string(&body).expect("serialize body");
|
||||
assert!(
|
||||
raw.contains("Kick off"),
|
||||
"objective missing from director request body"
|
||||
);
|
||||
|
||||
// Stop the background task to end the test.
|
||||
fut.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn final_delivery_request_includes_output_schema() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
// 1) Solver: emit empty message so orchestrator asks for final_delivery via schema.
|
||||
let body_solver = responses::sse(vec![
|
||||
responses::ev_response_created("solver-resp-1"),
|
||||
// No signal -> orchestrator will prompt with final_output schema.
|
||||
responses::ev_completed("solver-resp-1"),
|
||||
]);
|
||||
let _mock_solver = responses::mount_sse_once(&server, body_solver).await;
|
||||
|
||||
// 2) Capture the schema-bearing request to Solver.
|
||||
let body_solver_prompt = responses::sse(vec![
|
||||
responses::ev_response_created("solver-resp-2"),
|
||||
responses::ev_assistant_message(
|
||||
"solver-msg-2",
|
||||
r#"{"type":"final_delivery","deliverable_path":"deliverable/summary.txt","summary":null}"#,
|
||||
),
|
||||
responses::ev_completed("solver-resp-2"),
|
||||
]);
|
||||
let mock_solver_prompt = responses::mount_sse_once(&server, body_solver_prompt).await;
|
||||
|
||||
// 3) Keep any follow-up quiet.
|
||||
let body_solver_done = responses::sse(vec![
|
||||
responses::ev_response_created("solver-resp-3"),
|
||||
responses::ev_completed("solver-resp-3"),
|
||||
]);
|
||||
let _mock_solver_done = responses::mount_sse_once(&server, body_solver_done).await;
|
||||
|
||||
let runs_root = TempDir::new()?;
|
||||
let orchestrator =
|
||||
InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
|
||||
let run_id = "run-final-schema".to_string();
|
||||
|
||||
let solver_config = build_config(&server).await?;
|
||||
let director_config = build_config(&server).await?;
|
||||
|
||||
let params = RunParams {
|
||||
run_id: run_id.clone(),
|
||||
run_root: Some(runs_root.path().join("runs").join(&run_id)),
|
||||
solver: RoleConfig::new("solver", solver_config),
|
||||
director: RoleConfig::new("director", director_config),
|
||||
verifiers: Vec::new(),
|
||||
};
|
||||
|
||||
let options = RunExecutionOptions {
|
||||
objective: Some("Kick off".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let fut = tokio::spawn(async move {
|
||||
let _ = orchestrator.execute_new_run(params, options).await;
|
||||
});
|
||||
|
||||
wait_for_requests(&mock_solver_prompt, 1, Duration::from_secs(2)).await;
|
||||
let req = mock_solver_prompt.single_request();
|
||||
let body = req.body_json();
|
||||
let text = &body["text"];
|
||||
assert!(text.is_object(), "missing text controls in request body");
|
||||
let fmt = &text["format"];
|
||||
assert!(fmt.is_object(), "missing text.format in request body");
|
||||
assert_eq!(fmt["type"], "json_schema");
|
||||
let schema = &fmt["schema"];
|
||||
assert!(schema.is_object(), "missing text.format.schema");
|
||||
    let required = schema["required"]
        .as_array()
        .expect("required must be array");
    let props = schema["properties"]
        .as_object()
        .expect("properties must be object");
    for key in props.keys() {
        assert!(
            required.iter().any(|v| v == key),
            "missing {key} in required"
        );
    }

    fut.abort();
    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn verifier_request_includes_output_schema() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;

    // 1) Solver: issue a final_delivery which triggers verifier requests.
    let body_solver = responses::sse(vec![
        responses::ev_response_created("solver-resp-1"),
        responses::ev_assistant_message(
            "solver-msg-1",
            r#"{"type":"final_delivery","deliverable_path":"deliverable/summary.txt","summary":null}"#,
        ),
        responses::ev_completed("solver-resp-1"),
    ]);
    let _mock_solver = responses::mount_sse_once(&server, body_solver).await;

    // 2) Verifier: reply with a verdict JSON.
    let body_verifier = responses::sse(vec![
        responses::ev_response_created("verifier-resp-1"),
        responses::ev_assistant_message(
            "verifier-msg-1",
            r#"{"verdict":"pass","reasons":[],"suggestions":[]}"#,
        ),
        responses::ev_completed("verifier-resp-1"),
    ]);
    let mock_verifier = responses::mount_sse_once(&server, body_verifier).await;

    // 3) After posting the summary back to Solver, let the request complete.
    let body_solver_after = responses::sse(vec![
        responses::ev_response_created("solver-resp-2"),
        responses::ev_completed("solver-resp-2"),
    ]);
    let _mock_solver_after = responses::mount_sse_once(&server, body_solver_after).await;

    let runs_root = TempDir::new()?;
    let orchestrator =
        InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
    let run_id = "run-verifier-schema".to_string();

    let solver_config = build_config(&server).await?;
    let director_config = build_config(&server).await?;
    let verifier_config = build_config(&server).await?;

    let params = RunParams {
        run_id: run_id.clone(),
        run_root: Some(runs_root.path().join("runs").join(&run_id)),
        solver: RoleConfig::new("solver", solver_config),
        director: RoleConfig::new("director", director_config),
        verifiers: vec![RoleConfig::new("verifier", verifier_config)],
    };

    let options = RunExecutionOptions {
        objective: Some("Kick off".to_string()),
        ..Default::default()
    };

    let fut = tokio::spawn(async move {
        let _ = orchestrator.execute_new_run(params, options).await;
    });

    // Wait until the Verifier request is captured.
    wait_for_requests(&mock_verifier, 1, Duration::from_secs(2)).await;
    let req = mock_verifier.single_request();
    let body = req.body_json();

    // Assert that a JSON schema was sent under text.format.
    let text = &body["text"]; // Optional; present when using schemas
    assert!(text.is_object(), "missing text controls in request body");
    let fmt = &text["format"];
    assert!(fmt.is_object(), "missing text.format in request body");
    assert_eq!(fmt["type"], "json_schema");
    let schema = &fmt["schema"];
    assert!(schema.is_object(), "missing text.format.schema");
    assert_eq!(schema["type"], "object");
    // Ensure the verdict property exists and is an enum of pass/fail.
    assert!(schema["properties"]["verdict"].is_object());
    // Enforce strictness: required must include all properties.
    let required = schema["required"]
        .as_array()
        .expect("required must be array");
    let props = schema["properties"]
        .as_object()
        .expect("properties must be object");
    for key in props.keys() {
        assert!(
            required.iter().any(|v| v == key),
            "missing {key} in required"
        );
    }

    fut.abort();
    Ok(())
}

async fn build_config(server: &MockServer) -> anyhow::Result<Config> {
    let home = TempDir::new()?;
    let cwd = TempDir::new()?;
    let mut config = load_default_config_for_test(&home);
    config.cwd = cwd.path().to_path_buf();
    let mut provider = built_in_model_providers()["openai"].clone();
    provider.base_url = Some(format!("{}/v1", server.uri()));
    config.model_provider = provider;
    Ok(config)
}
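// Polling helper: returns once the mock has observed at least `min` requests,
// or quietly after `timeout` so the caller's assertions report the failure
// instead of the test hanging.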
async fn wait_for_requests(mock: &responses::ResponseMock, min: usize, timeout: Duration) {
    use tokio::time::Instant;
    use tokio::time::sleep;
    let start = Instant::now();
    loop {
        if mock.requests().len() >= min {
            return;
        }
        if start.elapsed() > timeout {
            return;
        }
        sleep(Duration::from_millis(25)).await;
    }
}
codex-rs/codex-infty/tests/timeouts.rs (new file, 98 lines)
@@ -0,0 +1,98 @@
#![cfg(not(target_os = "windows"))]

use std::time::Duration;

use codex_core::CodexAuth;
use codex_core::built_in_model_providers;
use codex_core::config::Config;
use codex_infty::InftyOrchestrator;
use codex_infty::RoleConfig;
use codex_infty::RunExecutionOptions;
use codex_infty::RunParams;
use core_test_support::load_default_config_for_test;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use tempfile::TempDir;
use wiremock::MockServer;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn direction_request_times_out_when_director_is_silent() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;

    // Solver emits a direction_request.
    let body_solver = responses::sse(vec![
        responses::ev_response_created("solver-resp-1"),
        responses::ev_assistant_message(
            "solver-msg-1",
            r#"{"type":"direction_request","prompt":"Need directive","claim_path":null,"notes":null,"deliverable_path":null,"summary":null}"#,
        ),
        responses::ev_completed("solver-resp-1"),
    ]);
    let _mock_solver = responses::mount_sse_once(&server, body_solver).await;

    // Director remains silent (no assistant message); the model completes immediately.
    let body_director_silent = responses::sse(vec![
        responses::ev_response_created("director-resp-1"),
        // intentionally no message
        responses::ev_completed("director-resp-1"),
    ]);
    let _mock_director = responses::mount_sse_once(&server, body_director_silent).await;

    // After attempting to relay a directive back to the solver, the orchestrator won't
    // proceed, as we time out waiting for the director; however, the solver will still
    // receive a follow-up post later in the flow, so we pre-mount an empty completion to
    // satisfy it if the code ever reaches that point in future changes.
    let body_solver_after = responses::sse(vec![
        responses::ev_response_created("solver-resp-2"),
        responses::ev_completed("solver-resp-2"),
    ]);
    let _mock_solver_after = responses::mount_sse_once(&server, body_solver_after).await;

    let runs_root = TempDir::new()?;
    let orchestrator =
        InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
    let run_id = "run-director-timeout".to_string();

    let solver_config = build_config(&server).await?;
    let director_config = build_config(&server).await?;

    let params = RunParams {
        run_id: run_id.clone(),
        run_root: Some(runs_root.path().join("runs").join(&run_id)),
        solver: RoleConfig::new("solver", solver_config),
        director: RoleConfig::new("director", director_config),
        verifiers: Vec::new(),
    };

    let options = RunExecutionOptions {
        objective: Some("Kick off".to_string()),
        director_timeout: Duration::from_millis(50),
        ..Default::default()
    };

    let err = orchestrator
        .execute_new_run(params, options)
        .await
        .err()
        .expect("expected timeout error");
    let msg = format!("{err:#}");
    assert!(
        msg.contains("timed out waiting") || msg.contains("AwaitTimeout"),
        "unexpected error: {msg}"
    );

    Ok(())
}
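// Shared helper: builds a default test config whose OpenAI provider points at
// the wiremock server, so every role talks to the mock instead of the real API.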
async fn build_config(server: &MockServer) -> anyhow::Result<Config> {
    let home = TempDir::new()?;
    let cwd = TempDir::new()?;
    let mut config = load_default_config_for_test(&home);
    config.cwd = cwd.path().to_path_buf();
    let mut provider = built_in_model_providers()["openai"].clone();
    provider.base_url = Some(format!("{}/v1", server.uri()));
    config.model_provider = provider;
    Ok(config)
}
codex-rs/codex-infty/tests/verifier_replacement.rs (new file, 157 lines)
@@ -0,0 +1,157 @@
#![cfg(not(target_os = "windows"))]

use std::time::Duration;

use codex_core::CodexAuth;
use codex_core::built_in_model_providers;
use codex_core::config::Config;
use codex_infty::InftyOrchestrator;
use codex_infty::RoleConfig;
use codex_infty::RunExecutionOptions;
use codex_infty::RunParams;
use core_test_support::load_default_config_for_test;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use tempfile::TempDir;
use wiremock::MockServer;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn replaces_passing_verifiers_and_keeps_failing() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;

    // Round 1: alpha passes, beta fails
    let body_verifier_alpha_r1 = responses::sse(vec![
        responses::ev_response_created("verifier-alpha-r1"),
        responses::ev_assistant_message(
            "verifier-alpha-msg-r1",
            r#"{"verdict":"pass","reasons":[],"suggestions":[]}"#,
        ),
        responses::ev_completed("verifier-alpha-r1"),
    ]);
    let body_verifier_beta_r1 = responses::sse(vec![
        responses::ev_response_created("verifier-beta-r1"),
        responses::ev_assistant_message(
            "verifier-beta-msg-r1",
            r#"{"verdict":"fail","reasons":["missing"],"suggestions":[]}"#,
        ),
        responses::ev_completed("verifier-beta-r1"),
    ]);

    // Round 2: both pass
    let body_verifier_alpha_r2 = responses::sse(vec![
        responses::ev_response_created("verifier-alpha-r2"),
        responses::ev_assistant_message(
            "verifier-alpha-msg-r2",
            r#"{"verdict":"pass","reasons":[],"suggestions":[]}"#,
        ),
        responses::ev_completed("verifier-alpha-r2"),
    ]);
    let body_verifier_beta_r2 = responses::sse(vec![
        responses::ev_response_created("verifier-beta-r2"),
        responses::ev_assistant_message(
            "verifier-beta-msg-r2",
            r#"{"verdict":"pass","reasons":[],"suggestions":[]}"#,
        ),
        responses::ev_completed("verifier-beta-r2"),
    ]);

    // Mount verifier SSE bodies in the exact order collect_verification_summary posts to verifiers.
    // The implementation posts sequentially in the order of sessions.verifiers.
    let _m1 = responses::mount_sse_once(&server, body_verifier_alpha_r1).await;
    let _m2 = responses::mount_sse_once(&server, body_verifier_beta_r1).await;
    let _m3 = responses::mount_sse_once(&server, body_verifier_alpha_r2).await;
    let _m4 = responses::mount_sse_once(&server, body_verifier_beta_r2).await;

    let runs_root = TempDir::new()?;
    let orchestrator =
        InftyOrchestrator::with_runs_root(CodexAuth::from_api_key("dummy-key"), runs_root.path());
    let run_id = "run-verifier-replacement".to_string();

    let solver_config = build_config(&server).await?;
    let director_config = build_config(&server).await?;
    let verifier_config = build_config(&server).await?;

    // Spawn run with two verifiers in known order.
    let mut sessions = orchestrator
        .spawn_run(RunParams {
            run_id: run_id.clone(),
            run_root: Some(runs_root.path().join("runs").join(&run_id)),
            solver: RoleConfig::new("solver", solver_config),
            director: RoleConfig::new("director", director_config),
            verifiers: vec![
                RoleConfig::new("verifier-alpha", verifier_config.clone()),
                RoleConfig::new("verifier-beta", verifier_config),
            ],
        })
        .await?;

    let alpha_initial = sessions
        .store
        .role_metadata("verifier-alpha")
        .and_then(|m| m.rollout_path.clone())
        .expect("alpha initial rollout path");
    let beta_initial = sessions
        .store
        .role_metadata("verifier-beta")
        .and_then(|m| m.rollout_path.clone())
        .expect("beta initial rollout path");

    let options = RunExecutionOptions {
        verifier_timeout: Duration::from_secs(2),
        ..Default::default()
    };

    // Round 1: alpha pass (should be replaced), beta fail (should be kept)
    let _summary1 = orchestrator
        .verify_round_for_test(&mut sessions, "memory/claims/c1.json", &options)
        .await?;

    let alpha_after_r1 = sessions
        .store
        .role_metadata("verifier-alpha")
        .and_then(|m| m.rollout_path.clone())
        .expect("alpha rollout after r1");
    let beta_after_r1 = sessions
        .store
        .role_metadata("verifier-beta")
        .and_then(|m| m.rollout_path.clone())
        .expect("beta rollout after r1");

    assert_ne!(
        alpha_initial, alpha_after_r1,
        "alpha should be replaced after pass"
    );
    assert_eq!(
        beta_initial, beta_after_r1,
        "beta should be kept after fail"
    );

    // Round 2: both pass; beta should be replaced now.
    let _summary2 = orchestrator
        .verify_round_for_test(&mut sessions, "memory/claims/c2.json", &options)
        .await?;
    let beta_after_r2 = sessions
        .store
        .role_metadata("verifier-beta")
        .and_then(|m| m.rollout_path.clone())
        .expect("beta rollout after r2");
    assert_ne!(
        beta_initial, beta_after_r2,
        "beta should be replaced after pass in r2"
    );

    Ok(())
}

async fn build_config(server: &MockServer) -> anyhow::Result<Config> {
    let home = TempDir::new()?;
    let cwd = TempDir::new()?;
    let mut config = load_default_config_for_test(&home);
    config.cwd = cwd.path().to_path_buf();
    let mut provider = built_in_model_providers()["openai"].clone();
    provider.base_url = Some(format!("{}/v1", server.uri()));
    config.model_provider = provider;
    Ok(config)
}
codex-rs/common/src/format_env_display.rs (new file, 66 lines)
@@ -0,0 +1,66 @@
use std::collections::HashMap;

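/// Renders explicit `KEY=value` pairs (sorted by key), then pass-through
/// variables as `VAR=$VAR`; returns "-" when there is nothing to display.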
pub fn format_env_display(env: Option<&HashMap<String, String>>, env_vars: &[String]) -> String {
    let mut parts: Vec<String> = Vec::new();

    if let Some(map) = env {
        let mut pairs: Vec<_> = map.iter().collect();
        pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
        parts.extend(
            pairs
                .into_iter()
                .map(|(key, value)| format!("{key}={value}")),
        );
    }

    if !env_vars.is_empty() {
        parts.extend(env_vars.iter().map(|var| format!("{var}=${var}")));
    }

    if parts.is_empty() {
        "-".to_string()
    } else {
        parts.join(", ")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn returns_dash_when_empty() {
        assert_eq!(format_env_display(None, &[]), "-");

        let empty_map = HashMap::new();
        assert_eq!(format_env_display(Some(&empty_map), &[]), "-");
    }

    #[test]
    fn formats_sorted_env_pairs() {
        let mut env = HashMap::new();
        env.insert("B".to_string(), "two".to_string());
        env.insert("A".to_string(), "one".to_string());

        assert_eq!(format_env_display(Some(&env), &[]), "A=one, B=two");
    }

    #[test]
    fn formats_env_vars_with_dollar_prefix() {
        let vars = vec!["TOKEN".to_string(), "PATH".to_string()];

        assert_eq!(format_env_display(None, &vars), "TOKEN=$TOKEN, PATH=$PATH");
    }

    #[test]
    fn combines_env_pairs_and_vars() {
        let mut env = HashMap::new();
        env.insert("HOME".to_string(), "/tmp".to_string());
        let vars = vec!["TOKEN".to_string()];

        assert_eq!(
            format_env_display(Some(&env), &vars),
            "HOME=/tmp, TOKEN=$TOKEN"
        );
    }
}
@@ -13,6 +13,9 @@ mod sandbox_mode_cli_arg;
#[cfg(feature = "cli")]
pub use sandbox_mode_cli_arg::SandboxModeCliArg;

+#[cfg(feature = "cli")]
+pub mod format_env_display;
+
#[cfg(any(feature = "cli", test))]
mod config_override;
@@ -62,6 +62,7 @@ tokio = { workspace = true, features = [
    "signal",
] }
tokio-util = { workspace = true, features = ["rt"] }
+tokio-stream = { workspace = true, features = ["sync"] }
toml = { workspace = true }
toml_edit = { workspace = true }
tracing = { workspace = true, features = ["log"] }
@@ -135,6 +135,10 @@ impl CodexAuth {
        self.get_current_token_data().and_then(|t| t.account_id)
    }

+    pub fn get_account_email(&self) -> Option<String> {
+        self.get_current_token_data().and_then(|t| t.id_token.email)
+    }
+
    pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
        self.get_current_token_data()
            .and_then(|t| t.id_token.chatgpt_plan_type)
@@ -5,6 +5,8 @@ use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
+use crate::error::ConnectionFailedError;
+use crate::error::ResponseStreamFailed;
use crate::error::Result;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;

@@ -309,7 +311,12 @@ pub(crate) async fn stream_chat_completions(
    match res {
        Ok(resp) if resp.status().is_success() => {
            let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
-            let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
+            let stream = resp.bytes_stream().map_err(|e| {
+                CodexErr::ResponseStreamFailed(ResponseStreamFailed {
+                    source: e,
+                    request_id: None,
+                })
+            });
            tokio::spawn(process_chat_sse(
                stream,
                tx_event,

@@ -349,7 +356,9 @@ pub(crate) async fn stream_chat_completions(
        }
        Err(e) => {
            if attempt > max_retries {
-                return Err(e.into());
+                return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
+                    source: e,
+                }));
            }
            let delay = backoff(attempt);
            tokio::time::sleep(delay).await;
@@ -5,6 +5,8 @@ use std::time::Duration;

use crate::AuthManager;
use crate::auth::CodexAuth;
+use crate::error::ConnectionFailedError;
+use crate::error::ResponseStreamFailed;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
use bytes::Bytes;

@@ -351,7 +353,12 @@ impl ModelClient {
        }

        // spawn task to process SSE
-        let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
+        let stream = resp.bytes_stream().map_err(move |e| {
+            CodexErr::ResponseStreamFailed(ResponseStreamFailed {
+                source: e,
+                request_id: request_id.clone(),
+            })
+        });
        tokio::spawn(process_sse(
            stream,
            tx_event,

@@ -431,7 +438,9 @@ impl ModelClient {
                request_id,
            })
        }
-        Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
+        Err(e) => Err(StreamAttemptError::RetryableTransportError(
+            CodexErr::ConnectionFailed(ConnectionFailedError { source: e }),
+        )),
    }
}

@@ -1030,6 +1039,7 @@ mod tests {
        "test",
        "test",
        None,
+        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),
@@ -28,6 +28,12 @@ use futures::future::BoxFuture;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
use mcp_types::CallToolResult;
+use mcp_types::ListResourceTemplatesRequestParams;
+use mcp_types::ListResourceTemplatesResult;
+use mcp_types::ListResourcesRequestParams;
+use mcp_types::ListResourcesResult;
+use mcp_types::ReadResourceRequestParams;
+use mcp_types::ReadResourceResult;
use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;

@@ -445,6 +451,7 @@ impl Session {
            config.model.as_str(),
            config.model_family.slug.as_str(),
            auth_manager.auth().and_then(|a| a.get_account_id()),
+            auth_manager.auth().and_then(|a| a.get_account_email()),
            auth_manager.auth().map(|a| a.mode),
            config.otel.log_user_prompt,
            terminal::user_agent(),

@@ -620,6 +627,7 @@ impl Session {
            warn!("Overwriting existing pending approval for sub_id: {event_id}");
        }

+        let parsed_cmd = parse_command(&command);
        let event = Event {
            id: event_id,
            msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {

@@ -627,6 +635,7 @@ impl Session {
                command,
                cwd,
                reason,
+                parsed_cmd,
            }),
        };
        self.send_event(event).await;

@@ -882,10 +891,7 @@ impl Session {
                call_id,
                command: command_for_display.clone(),
                cwd,
-                parsed_cmd: parse_command(&command_for_display)
-                    .into_iter()
-                    .map(Into::into)
-                    .collect(),
+                parsed_cmd: parse_command(&command_for_display),
            }),
        };
        let event = Event {

@@ -910,6 +916,7 @@ impl Session {
            duration,
            exit_code,
+            timed_out: _,
            ..
        } = output;
        // Send full stdout/stderr to clients; do not truncate.
        let stdout = stdout.text.clone();

@@ -974,15 +981,28 @@ impl Session {
        let sub_id = context.sub_id.clone();
        let call_id = context.call_id.clone();

-        self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
-            .await;

+        let begin_turn_diff = turn_diff_tracker.clone();
+        let begin_context = context.clone();
+        let session = self;
        let result = self
            .services
            .executor
-            .run(request, self, approval_policy, &context)
+            .run(request, self, approval_policy, &context, move || {
+                let turn_diff = begin_turn_diff.clone();
+                let ctx = begin_context.clone();
+                async move {
+                    session.on_exec_command_begin(turn_diff, ctx).await;
+                }
+            })
            .await;

+        if matches!(
+            &result,
+            Err(ExecError::Function(FunctionCallError::Denied(_)))
+        ) {
+            return result;
+        }

        let normalized = normalize_exec_result(&result);
        let borrowed = normalized.event_output();

@@ -1014,11 +1034,16 @@ impl Session {
    }

    async fn notify_stream_error(&self, sub_id: &str, message: impl Into<String>) {
+        let message = message.into();
+        warn!(
+            conversation_id = %self.conversation_id,
+            sub_id = %sub_id,
+            %message,
+            "stream error while streaming model response",
+        );
        let event = Event {
            id: sub_id.to_string(),
-            msg: EventMsg::StreamError(StreamErrorEvent {
-                message: message.into(),
-            }),
+            msg: EventMsg::StreamError(StreamErrorEvent { message }),
        };
        self.send_event(event).await;
    }

@@ -1057,6 +1082,39 @@ impl Session {
        }
    }

+    pub async fn list_resources(
+        &self,
+        server: &str,
+        params: Option<ListResourcesRequestParams>,
+    ) -> anyhow::Result<ListResourcesResult> {
+        self.services
+            .mcp_connection_manager
+            .list_resources(server, params)
+            .await
+    }
+
+    pub async fn list_resource_templates(
+        &self,
+        server: &str,
+        params: Option<ListResourceTemplatesRequestParams>,
+    ) -> anyhow::Result<ListResourceTemplatesResult> {
+        self.services
+            .mcp_connection_manager
+            .list_resource_templates(server, params)
+            .await
+    }
+
+    pub async fn read_resource(
+        &self,
+        server: &str,
+        params: ReadResourceRequestParams,
+    ) -> anyhow::Result<ReadResourceResult> {
+        self.services
+            .mcp_connection_manager
+            .read_resource(server, params)
+            .await
+    }
+
    pub async fn call_tool(
        &self,
        server: &str,

@@ -1419,16 +1477,23 @@ async fn submission_loop(

                // This is a cheap lookup from the connection manager's cache.
                let tools = sess.services.mcp_connection_manager.list_all_tools();
-                let auth_statuses = compute_auth_statuses(
-                    config.mcp_servers.iter(),
-                    config.mcp_oauth_credentials_store_mode,
-                )
-                .await;
+                let (auth_statuses, resources, resource_templates) = tokio::join!(
+                    compute_auth_statuses(
+                        config.mcp_servers.iter(),
+                        config.mcp_oauth_credentials_store_mode,
+                    ),
+                    sess.services.mcp_connection_manager.list_all_resources(),
+                    sess.services
+                        .mcp_connection_manager
+                        .list_all_resource_templates()
+                );
                let event = Event {
                    id: sub_id,
                    msg: EventMsg::McpListToolsResponse(
                        crate::protocol::McpListToolsResponseEvent {
                            tools,
+                            resources,
+                            resource_templates,
                            auth_statuses,
                        },
                    ),

@@ -2216,7 +2281,8 @@ async fn try_run_turn(
                response: Some(response),
            });
        }
-        Err(FunctionCallError::RespondToModel(message)) => {
+        Err(FunctionCallError::RespondToModel(message))
+        | Err(FunctionCallError::Denied(message)) => {
            let response = ResponseInputItem::FunctionCallOutput {
                call_id: String::new(),
                output: FunctionCallOutputPayload {

@@ -2745,6 +2811,7 @@ mod tests {
            config.model.as_str(),
            config.model_family.slug.as_str(),
            None,
+            Some("test@test.com".to_string()),
            Some(AuthMode::ChatGPT),
            false,
            "test".to_string(),
@@ -7,6 +7,7 @@ use crate::config_types::DEFAULT_OTEL_ENVIRONMENT;
use crate::config_types::History;
use crate::config_types::McpServerConfig;
use crate::config_types::McpServerTransportConfig;
+use crate::config_types::Notice;
use crate::config_types::Notifications;
use crate::config_types::OtelConfig;
use crate::config_types::OtelConfigToml;

@@ -28,6 +29,8 @@ use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
+use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME;
+use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use anyhow::Context;

@@ -40,6 +43,7 @@ use codex_protocol::config_types::Verbosity;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use dirs::home_dir;
use serde::Deserialize;
+use similar::DiffableStr;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::ErrorKind;

@@ -98,6 +102,10 @@ pub struct Config {

    pub sandbox_policy: SandboxPolicy,

+    /// True if the user passed in an override or set a value in config.toml
+    /// for either of approval_policy or sandbox_mode.
+    pub did_user_set_custom_approval_policy_or_sandbox_mode: bool,
+
    pub shell_environment_policy: ShellEnvironmentPolicy,

    /// When `true`, `AgentReasoning` events emitted by the backend will be

@@ -228,9 +236,16 @@ pub struct Config {
    /// The active profile name used to derive this `Config` (if any).
    pub active_profile: Option<String>,

+    /// The currently active project config, resolved by checking whether cwd
+    /// is (1) part of a git repo, (2) a git worktree, or (3) just the cwd itself.
+    pub active_project: ProjectConfig,
+
    /// Tracks whether the Windows onboarding screen has been acknowledged.
    pub windows_wsl_setup_acknowledged: bool,

+    /// Collection of various notices we show the user
+    pub notices: Notice,
+
    /// When true, disables burst-paste detection for typed input entirely.
    /// All characters are inserted as they are received, and no buffering
    /// or placeholder replacement will occur for fast keypress bursts.

@@ -373,7 +388,13 @@ pub fn write_global_mcp_servers(
        let mut entry = TomlTable::new();
        entry.set_implicit(false);
        match &config.transport {
-            McpServerTransportConfig::Stdio { command, args, env } => {
+            McpServerTransportConfig::Stdio {
+                command,
+                args,
+                env,
+                env_vars,
+                cwd,
+            } => {
                entry["command"] = toml_edit::value(command.clone());

                if !args.is_empty() {

@@ -396,15 +417,50 @@ pub fn write_global_mcp_servers(
                    }
                    entry["env"] = TomlItem::Table(env_table);
                }

+                if !env_vars.is_empty() {
+                    entry["env_vars"] =
+                        TomlItem::Value(env_vars.iter().collect::<TomlArray>().into());
+                }
+
+                if let Some(cwd) = cwd {
+                    entry["cwd"] = toml_edit::value(cwd.to_string_lossy().to_string());
+                }
            }
            McpServerTransportConfig::StreamableHttp {
                url,
                bearer_token_env_var,
+                http_headers,
+                env_http_headers,
            } => {
                entry["url"] = toml_edit::value(url.clone());
                if let Some(env_var) = bearer_token_env_var {
                    entry["bearer_token_env_var"] = toml_edit::value(env_var.clone());
                }
+                if let Some(headers) = http_headers
+                    && !headers.is_empty()
+                {
+                    let mut table = TomlTable::new();
+                    table.set_implicit(false);
+                    let mut pairs: Vec<_> = headers.iter().collect();
+                    pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
+                    for (key, value) in pairs {
+                        table.insert(key, toml_edit::value(value.clone()));
+                    }
+                    entry["http_headers"] = TomlItem::Table(table);
+                }
+                if let Some(headers) = env_http_headers
+                    && !headers.is_empty()
+                {
+                    let mut table = TomlTable::new();
+                    table.set_implicit(false);
+                    let mut pairs: Vec<_> = headers.iter().collect();
+                    pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
+                    for (key, value) in pairs {
+                        table.insert(key, toml_edit::value(value.clone()));
+                    }
+                    entry["env_http_headers"] = TomlItem::Table(table);
+                }
            }
        }

@@ -546,6 +602,54 @@ pub fn set_windows_wsl_setup_acknowledged(
    Ok(())
}

+/// Persist the acknowledgement flag for the full access warning prompt.
+pub fn set_hide_full_access_warning(codex_home: &Path, acknowledged: bool) -> anyhow::Result<()> {
+    let config_path = codex_home.join(CONFIG_TOML_FILE);
+    let mut doc = match std::fs::read_to_string(config_path.clone()) {
+        Ok(s) => s.parse::<DocumentMut>()?,
+        Err(e) if e.kind() == std::io::ErrorKind::NotFound => DocumentMut::new(),
+        Err(e) => return Err(e.into()),
+    };
+
+    let notices_table = load_or_create_top_level_table(&mut doc, Notice::TABLE_KEY)?;
+
+    notices_table["hide_full_access_warning"] = toml_edit::value(acknowledged);
+
+    std::fs::create_dir_all(codex_home)?;
+    let tmp_file = NamedTempFile::new_in(codex_home)?;
+    std::fs::write(tmp_file.path(), doc.to_string())?;
+    tmp_file.persist(config_path)?;
+
+    Ok(())
+}
+
+fn load_or_create_top_level_table<'a>(
+    doc: &'a mut DocumentMut,
+    key: &str,
+) -> anyhow::Result<&'a mut toml_edit::Table> {
+    let mut created_table = false;
+
+    let root = doc.as_table_mut();
+    let needs_table =
+        !root.contains_key(key) || root.get(key).and_then(|item| item.as_table()).is_none();
+    if needs_table {
+        root.insert(key, toml_edit::table());
+        created_table = true;
+    }
+
+    let Some(table) = doc[key].as_table_mut() else {
+        return Err(anyhow::anyhow!(format!(
+            "table [{key}] missing after initialization"
+        )));
+    };
+
+    if created_table {
+        table.set_implicit(true);
+    }
+
+    Ok(table)
+}
+
fn ensure_profile_table<'a>(
    doc: &'a mut DocumentMut,
    profile_name: &str,

@@ -821,6 +925,10 @@ pub struct ConfigToml {
    /// Tracks whether the Windows onboarding screen has been acknowledged.
    pub windows_wsl_setup_acknowledged: Option<bool>,

+    /// Collection of in-product notices (different from notifications)
+    /// See [`crate::config_types::Notices`] for more details
+    pub notice: Option<Notice>,
+
    /// Legacy, now use features
    pub experimental_instructions_file: Option<PathBuf>,
    pub experimental_use_exec_command_tool: Option<bool>,

@@ -857,6 +965,15 @@ pub struct ProjectConfig {
    pub trust_level: Option<String>,
}

+impl ProjectConfig {
+    pub fn is_trusted(&self) -> bool {
+        match &self.trust_level {
+            Some(trust_level) => trust_level == "trusted",
+            None => false,
+        }
+    }
+}
+
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct ToolsToml {
    #[serde(default, alias = "web_search_request")]

@@ -878,9 +995,23 @@ impl From<ToolsToml> for Tools {

impl ConfigToml {
    /// Derive the effective sandbox policy from the configuration.
-    fn derive_sandbox_policy(&self, sandbox_mode_override: Option<SandboxMode>) -> SandboxPolicy {
+    fn derive_sandbox_policy(
+        &self,
+        sandbox_mode_override: Option<SandboxMode>,
+        resolved_cwd: &Path,
+    ) -> SandboxPolicy {
        let resolved_sandbox_mode = sandbox_mode_override
            .or(self.sandbox_mode)
+            .or_else(|| {
+                // if no sandbox_mode is set, but user has marked directory as trusted, use WorkspaceWrite
+                self.get_active_project(resolved_cwd).and_then(|p| {
+                    if p.is_trusted() {
+                        Some(SandboxMode::WorkspaceWrite)
+                    } else {
+                        None
+                    }
+                })
+            })
            .unwrap_or_default();
        match resolved_sandbox_mode {
            SandboxMode::ReadOnly => SandboxPolicy::new_read_only_policy(),
@@ -902,30 +1033,26 @@ impl ConfigToml {
        }
    }

-    pub fn is_cwd_trusted(&self, resolved_cwd: &Path) -> bool {
+    /// Resolves the cwd to an existing project, or returns None if ConfigToml
+    /// does not contain a project corresponding to cwd or a git repo for cwd.
+    pub fn get_active_project(&self, resolved_cwd: &Path) -> Option<ProjectConfig> {
        let projects = self.projects.clone().unwrap_or_default();

-        let is_path_trusted = |path: &Path| {
-            let path_str = path.to_string_lossy().to_string();
-            projects
-                .get(&path_str)
-                .map(|p| p.trust_level.as_deref() == Some("trusted"))
-                .unwrap_or(false)
-        };
-
        // Fast path: exact cwd match
-        if is_path_trusted(resolved_cwd) {
-            return true;
+        if let Some(project_config) = projects.get(&resolved_cwd.to_string_lossy().to_string()) {
+            return Some(project_config.clone());
        }

-        // If cwd lives inside a git worktree, check whether the root git project
+        // If cwd lives inside a git repo/worktree, check whether the root git project
        // (the primary repository working directory) is trusted. This lets
        // worktrees inherit trust from the main project.
-        if let Some(root_project) = resolve_root_git_project_for_trust(resolved_cwd) {
-            return is_path_trusted(&root_project);
+        if let Some(repo_root) = resolve_root_git_project_for_trust(resolved_cwd)
+            && let Some(project_config_for_root) =
+                projects.get(&repo_root.to_string_lossy().to_string())
+        {
+            return Some(project_config_for_root.clone());
        }

-        false
+        None
    }

    pub fn get_config_profile(
@@ -984,7 +1111,7 @@ impl Config {
            model,
            review_model: override_review_model,
            cwd,
-            approval_policy,
+            approval_policy: approval_policy_override,
            sandbox_mode,
            model_provider,
            config_profile: config_profile_key,

@@ -1024,7 +1151,47 @@ impl Config {

        let features = Features::from_config(&cfg, &config_profile, feature_overrides);

-        let sandbox_policy = cfg.derive_sandbox_policy(sandbox_mode);
+        let resolved_cwd = {
+            use std::env;
+
+            match cwd {
+                None => {
+                    tracing::info!("cwd not set, using current dir");
+                    env::current_dir()?
+                }
+                Some(p) if p.is_absolute() => p,
+                Some(p) => {
+                    // Resolve relative path against the current working directory.
+                    tracing::info!("cwd is relative, resolving against current dir");
+                    let mut current = env::current_dir()?;
+                    current.push(p);
+                    current
+                }
+            }
+        };
+        let active_project = cfg
+            .get_active_project(&resolved_cwd)
+            .unwrap_or(ProjectConfig { trust_level: None });
+
+        let sandbox_policy = cfg.derive_sandbox_policy(sandbox_mode, &resolved_cwd);
+        let mut approval_policy = approval_policy_override
+            .or(config_profile.approval_policy)
+            .or(cfg.approval_policy)
+            .unwrap_or_else(|| {
+                if active_project.is_trusted() {
+                    // If no explicit approval policy is set, but we trust cwd, default to OnRequest
+                    AskForApproval::OnRequest
+                } else {
+                    AskForApproval::default()
+                }
+            });
+        let did_user_set_custom_approval_policy_or_sandbox_mode = approval_policy_override
+            .is_some()
+            || config_profile.approval_policy.is_some()
+            || cfg.approval_policy.is_some()
+            // TODO(#3034): profile.sandbox_mode is not implemented
+            || sandbox_mode.is_some()
+            || cfg.sandbox_mode.is_some();

        let mut model_providers = built_in_model_providers();
        // Merge user-defined providers into the built-in list.

@@ -1048,25 +1215,6 @@ impl Config {

        let shell_environment_policy = cfg.shell_environment_policy.into();

-        let resolved_cwd = {
-            use std::env;
-
-            match cwd {
-                None => {
-                    tracing::info!("cwd not set, using current dir");
-                    env::current_dir()?
-                }
-                Some(p) if p.is_absolute() => p,
-                Some(p) => {
-                    // Resolve relative path against the current working directory.
-                    tracing::info!("cwd is relative, resolving against current dir");
-                    let mut current = env::current_dir()?;
-                    current.push(p);
-                    current
-                }
-            }
-        };
-
        let history = cfg.history.unwrap_or_default();

        let include_plan_tool_flag = features.enabled(Feature::PlanTool);

@@ -1123,6 +1271,10 @@ impl Config {
            .or(cfg.review_model)
            .unwrap_or_else(default_review_model);

+        if features.enabled(Feature::ApproveAll) {
+            approval_policy = AskForApproval::OnRequest;
+        }
+
        let config = Self {
            model,
            review_model,

@@ -1133,11 +1285,9 @@ impl Config {
            model_provider_id,
            model_provider,
            cwd: resolved_cwd,
-            approval_policy: approval_policy
-                .or(config_profile.approval_policy)
-                .or(cfg.approval_policy)
-                .unwrap_or_else(AskForApproval::default),
+            approval_policy,
            sandbox_policy,
+            did_user_set_custom_approval_policy_or_sandbox_mode,
            shell_environment_policy,
            notify: cfg.notify,
            user_instructions,

@@ -1192,7 +1342,9 @@ impl Config {
            include_view_image_tool: include_view_image_tool_flag,
            features,
            active_profile: active_profile_name,
+            active_project,
            windows_wsl_setup_acknowledged: cfg.windows_wsl_setup_acknowledged.unwrap_or(false),
+            notices: cfg.notice.unwrap_or_default(),
            disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
            tui_notifications: cfg
                .tui

@@ -1217,20 +1369,18 @@ impl Config {
}

fn load_instructions(codex_dir: Option<&Path>) -> Option<String> {
-    let mut p = match codex_dir {
-        Some(p) => p.to_path_buf(),
-        None => return None,
-    };
-
-    p.push("AGENTS.md");
-    std::fs::read_to_string(&p).ok().and_then(|s| {
-        let s = s.trim();
-        if s.is_empty() {
-            None
-        } else {
-            Some(s.to_string())
-        }
-    })
+    let base = codex_dir?;
+    for candidate in [LOCAL_PROJECT_DOC_FILENAME, DEFAULT_PROJECT_DOC_FILENAME] {
+        let mut path = base.to_path_buf();
+        path.push(candidate);
+        if let Ok(contents) = std::fs::read_to_string(&path) {
+            let trimmed = contents.trim();
+            if !trimmed.is_empty() {
+                return Some(trimmed.to_string());
+            }
+        }
+    }
+    None
}

fn get_base_instructions(

@@ -1389,7 +1539,8 @@ network_access = false # This should be ignored.
    let sandbox_mode_override = None;
    assert_eq!(
        SandboxPolicy::DangerFullAccess,
-        sandbox_full_access_cfg.derive_sandbox_policy(sandbox_mode_override)
+        sandbox_full_access_cfg
+            .derive_sandbox_policy(sandbox_mode_override, &PathBuf::from("/tmp/test"))
    );

    let sandbox_read_only = r#"

@@ -1404,7 +1555,8 @@ network_access = true # This should be ignored.
    let sandbox_mode_override = None;
    assert_eq!(
        SandboxPolicy::ReadOnly,
-        sandbox_read_only_cfg.derive_sandbox_policy(sandbox_mode_override)
+        sandbox_read_only_cfg
+            .derive_sandbox_policy(sandbox_mode_override, &PathBuf::from("/tmp/test"))
    );

    let sandbox_workspace_write = r#"

@@ -1428,8 +1580,57 @@ exclude_slash_tmp = true
            exclude_tmpdir_env_var: true,
            exclude_slash_tmp: true,
        },
-        sandbox_workspace_write_cfg.derive_sandbox_policy(sandbox_mode_override)
+        sandbox_workspace_write_cfg
+            .derive_sandbox_policy(sandbox_mode_override, &PathBuf::from("/tmp/test"))
    );
+
+    let sandbox_workspace_write = r#"
+sandbox_mode = "workspace-write"
+
+[sandbox_workspace_write]
+writable_roots = [
+    "/my/workspace",
+]
+exclude_tmpdir_env_var = true
+exclude_slash_tmp = true
+
+[projects."/tmp/test"]
+trust_level = "trusted"
+"#;
+
+    let sandbox_workspace_write_cfg = toml::from_str::<ConfigToml>(sandbox_workspace_write)
+        .expect("TOML deserialization should succeed");
+    let sandbox_mode_override = None;
+    assert_eq!(
+        SandboxPolicy::WorkspaceWrite {
+            writable_roots: vec![PathBuf::from("/my/workspace")],
+            network_access: false,
+            exclude_tmpdir_env_var: true,
+            exclude_slash_tmp: true,
+        },
+        sandbox_workspace_write_cfg
+            .derive_sandbox_policy(sandbox_mode_override, &PathBuf::from("/tmp/test"))
+    );
}

+#[test]
+fn approve_all_feature_forces_on_request_policy() -> std::io::Result<()> {
+    let cfg = r#"
+[features]
+approve_all = true
+"#;
+    let parsed = toml::from_str::<ConfigToml>(cfg)
+        .expect("TOML deserialization should succeed for approve_all feature");
+    let temp_dir = TempDir::new()?;
+    let config = Config::load_from_base_config_with_overrides(
+        parsed,
+        ConfigOverrides::default(),
+        temp_dir.path().to_path_buf(),
+    )?;
+
+    assert!(config.features.enabled(Feature::ApproveAll));
+    assert_eq!(config.approval_policy, AskForApproval::OnRequest);
+    Ok(())
+}
+
#[test]
@@ -1620,6 +1821,8 @@ exclude_slash_tmp = true
                command: "echo".to_string(),
                args: vec!["hello".to_string()],
                env: None,
+                env_vars: Vec::new(),
+                cwd: None,
            },
            enabled: true,
            startup_timeout_sec: Some(Duration::from_secs(3)),

@@ -1633,10 +1836,18 @@ exclude_slash_tmp = true
    assert_eq!(loaded.len(), 1);
    let docs = loaded.get("docs").expect("docs entry");
    match &docs.transport {
-        McpServerTransportConfig::Stdio { command, args, env } => {
+        McpServerTransportConfig::Stdio {
+            command,
+            args,
+            env,
+            env_vars,
+            cwd,
+        } => {
            assert_eq!(command, "echo");
            assert_eq!(args, &vec!["hello".to_string()]);
            assert!(env.is_none());
+            assert!(env_vars.is_empty());
+            assert!(cwd.is_none());
        }
        other => panic!("unexpected transport {other:?}"),
    }

@@ -1746,6 +1957,8 @@ bearer_token = "secret"
                ("ZIG_VAR".to_string(), "3".to_string()),
                ("ALPHA_VAR".to_string(), "1".to_string()),
            ])),
+            env_vars: Vec::new(),
+            cwd: None,
        },
        enabled: true,
        startup_timeout_sec: None,

@@ -1772,7 +1985,13 @@ ZIG_VAR = "3"
    let loaded = load_global_mcp_servers(codex_home.path()).await?;
    let docs = loaded.get("docs").expect("docs entry");
    match &docs.transport {
-        McpServerTransportConfig::Stdio { command, args, env } => {
+        McpServerTransportConfig::Stdio {
+            command,
+            args,
+            env,
+            env_vars,
+            cwd,
+        } => {
            assert_eq!(command, "docs-server");
            assert_eq!(args, &vec!["--verbose".to_string()]);
            let env = env

@@ -1780,6 +1999,8 @@ ZIG_VAR = "3"
                .expect("env should be preserved for stdio transport");
            assert_eq!(env.get("ALPHA_VAR"), Some(&"1".to_string()));
            assert_eq!(env.get("ZIG_VAR"), Some(&"3".to_string()));
+            assert!(env_vars.is_empty());
+            assert!(cwd.is_none());
        }
        other => panic!("unexpected transport {other:?}"),
    }

@@ -1788,15 +2009,101 @@ ZIG_VAR = "3"
}

#[tokio::test]
-async fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> {
+async fn write_global_mcp_servers_serializes_env_vars() -> anyhow::Result<()> {
    let codex_home = TempDir::new()?;

-    let mut servers = BTreeMap::from([(
+    let servers = BTreeMap::from([(
        "docs".to_string(),
        McpServerConfig {
+            transport: McpServerTransportConfig::Stdio {
+                command: "docs-server".to_string(),
+                args: Vec::new(),
+                env: None,
+                env_vars: vec!["ALPHA".to_string(), "BETA".to_string()],
+                cwd: None,
+            },
+            enabled: true,
+            startup_timeout_sec: None,
+            tool_timeout_sec: None,
+        },
+    )]);
+
+    write_global_mcp_servers(codex_home.path(), &servers)?;
+
+    let config_path = codex_home.path().join(CONFIG_TOML_FILE);
+    let serialized = std::fs::read_to_string(&config_path)?;
+    assert!(
+        serialized.contains(r#"env_vars = ["ALPHA", "BETA"]"#),
+        "serialized config missing env_vars field:\n{serialized}"
+    );
+
+    let loaded = load_global_mcp_servers(codex_home.path()).await?;
+    let docs = loaded.get("docs").expect("docs entry");
+    match &docs.transport {
+        McpServerTransportConfig::Stdio { env_vars, .. } => {
+            assert_eq!(env_vars, &vec!["ALPHA".to_string(), "BETA".to_string()]);
+        }
+        other => panic!("unexpected transport {other:?}"),
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn write_global_mcp_servers_serializes_cwd() -> anyhow::Result<()> {
+    let codex_home = TempDir::new()?;
+
+    let cwd_path = PathBuf::from("/tmp/codex-mcp");
+    let servers = BTreeMap::from([(
+        "docs".to_string(),
+        McpServerConfig {
+            transport: McpServerTransportConfig::Stdio {
+                command: "docs-server".to_string(),
+                args: Vec::new(),
+                env: None,
+                env_vars: Vec::new(),
+                cwd: Some(cwd_path.clone()),
+            },
+            enabled: true,
+            startup_timeout_sec: None,
+            tool_timeout_sec: None,
+        },
+    )]);
+
+    write_global_mcp_servers(codex_home.path(), &servers)?;
+
+    let config_path = codex_home.path().join(CONFIG_TOML_FILE);
+    let serialized = std::fs::read_to_string(&config_path)?;
+    assert!(
+        serialized.contains(r#"cwd = "/tmp/codex-mcp""#),
+        "serialized config missing cwd field:\n{serialized}"
+    );
+
+    let loaded = load_global_mcp_servers(codex_home.path()).await?;
+    let docs = loaded.get("docs").expect("docs entry");
+    match &docs.transport {
+        McpServerTransportConfig::Stdio { cwd, .. } => {
+            assert_eq!(cwd.as_deref(), Some(Path::new("/tmp/codex-mcp")));
+        }
+        other => panic!("unexpected transport {other:?}"),
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn write_global_mcp_servers_streamable_http_serializes_bearer_token() -> anyhow::Result<()>
+{
+    let codex_home = TempDir::new()?;
+
+    let servers = BTreeMap::from([(
+        "docs".to_string(),
+        McpServerConfig {
            transport: McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: Some("MCP_TOKEN".to_string()),
+                http_headers: None,
+                env_http_headers: None,
            },
            enabled: true,
            startup_timeout_sec: Some(Duration::from_secs(2)),
@@ -1823,20 +2130,127 @@ startup_timeout_sec = 2.0
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert_eq!(bearer_token_env_var.as_deref(), Some("MCP_TOKEN"));
|
||||
assert!(http_headers.is_none());
|
||||
assert!(env_http_headers.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(2)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_streamable_http_serializes_custom_headers()
|
||||
-> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let servers = BTreeMap::from([(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token_env_var: Some("MCP_TOKEN".to_string()),
|
||||
http_headers: Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])),
|
||||
env_http_headers: Some(HashMap::from([(
|
||||
"X-Auth".to_string(),
|
||||
"DOCS_AUTH".to_string(),
|
||||
)])),
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(2)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
)]);
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let serialized = std::fs::read_to_string(&config_path)?;
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token_env_var = "MCP_TOKEN"
|
||||
startup_timeout_sec = 2.0
|
||||
|
||||
[mcp_servers.docs.http_headers]
|
||||
X-Doc = "42"
|
||||
|
||||
[mcp_servers.docs.env_http_headers]
|
||||
X-Auth = "DOCS_AUTH"
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path()).await?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(
|
||||
http_headers,
|
||||
&Some(HashMap::from([("X-Doc".to_string(), "42".to_string())]))
|
||||
);
|
||||
assert_eq!(
|
||||
env_http_headers,
|
||||
&Some(HashMap::from([(
|
||||
"X-Auth".to_string(),
|
||||
"DOCS_AUTH".to_string()
|
||||
)]))
|
||||
);
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_global_mcp_servers_streamable_http_removes_optional_sections()
|
||||
-> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
let mut servers = BTreeMap::from([(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token_env_var: Some("MCP_TOKEN".to_string()),
|
||||
http_headers: Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])),
|
||||
env_http_headers: Some(HashMap::from([(
|
||||
"X-Auth".to_string(),
|
||||
"DOCS_AUTH".to_string(),
|
||||
)])),
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: Some(Duration::from_secs(2)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
)]);
|
||||
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
let serialized_with_optional = std::fs::read_to_string(&config_path)?;
|
||||
assert!(serialized_with_optional.contains("bearer_token_env_var = \"MCP_TOKEN\""));
|
||||
assert!(serialized_with_optional.contains("[mcp_servers.docs.http_headers]"));
|
||||
assert!(serialized_with_optional.contains("[mcp_servers.docs.env_http_headers]"));
|
||||
|
||||
servers.insert(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
enabled: true,
|
||||
startup_timeout_sec: None,
|
||||
@@ -1859,9 +2273,112 @@ url = "https://example.com/mcp"
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            assert_eq!(url, "https://example.com/mcp");
            assert!(bearer_token_env_var.is_none());
            assert!(http_headers.is_none());
            assert!(env_http_headers.is_none());
        }
        other => panic!("unexpected transport {other:?}"),
    }

    assert!(docs.startup_timeout_sec.is_none());

    Ok(())
}

#[tokio::test]
async fn write_global_mcp_servers_streamable_http_isolates_headers_between_servers()
-> anyhow::Result<()> {
    let codex_home = TempDir::new()?;
    let config_path = codex_home.path().join(CONFIG_TOML_FILE);

    let servers = BTreeMap::from([
        (
            "docs".to_string(),
            McpServerConfig {
                transport: McpServerTransportConfig::StreamableHttp {
                    url: "https://example.com/mcp".to_string(),
                    bearer_token_env_var: Some("MCP_TOKEN".to_string()),
                    http_headers: Some(HashMap::from([(
                        "X-Doc".to_string(),
                        "42".to_string(),
                    )])),
                    env_http_headers: Some(HashMap::from([(
                        "X-Auth".to_string(),
                        "DOCS_AUTH".to_string(),
                    )])),
                },
                enabled: true,
                startup_timeout_sec: Some(Duration::from_secs(2)),
                tool_timeout_sec: None,
            },
        ),
        (
            "logs".to_string(),
            McpServerConfig {
                transport: McpServerTransportConfig::Stdio {
                    command: "logs-server".to_string(),
                    args: vec!["--follow".to_string()],
                    env: None,
                    env_vars: Vec::new(),
                    cwd: None,
                },
                enabled: true,
                startup_timeout_sec: None,
                tool_timeout_sec: None,
            },
        ),
    ]);

    write_global_mcp_servers(codex_home.path(), &servers)?;

    let serialized = std::fs::read_to_string(&config_path)?;
    assert!(
        serialized.contains("[mcp_servers.docs.http_headers]"),
        "serialized config missing docs headers section:\n{serialized}"
    );
    assert!(
        !serialized.contains("[mcp_servers.logs.http_headers]"),
        "serialized config should not add logs headers section:\n{serialized}"
    );
    assert!(
        !serialized.contains("[mcp_servers.logs.env_http_headers]"),
        "serialized config should not add logs env headers section:\n{serialized}"
    );
    assert!(
        !serialized.contains("mcp_servers.logs.bearer_token_env_var"),
        "serialized config should not add bearer token to logs:\n{serialized}"
    );

    let loaded = load_global_mcp_servers(codex_home.path()).await?;
    let docs = loaded.get("docs").expect("docs entry");
    match &docs.transport {
        McpServerTransportConfig::StreamableHttp {
            http_headers,
            env_http_headers,
            ..
        } => {
            assert_eq!(
                http_headers,
                &Some(HashMap::from([("X-Doc".to_string(), "42".to_string())]))
            );
            assert_eq!(
                env_http_headers,
                &Some(HashMap::from([(
                    "X-Auth".to_string(),
                    "DOCS_AUTH".to_string()
                )]))
            );
        }
        other => panic!("unexpected transport {other:?}"),
    }
    let logs = loaded.get("logs").expect("logs entry");
    match &logs.transport {
        McpServerTransportConfig::Stdio { env, .. } => {
            assert!(env.is_none());
        }
        other => panic!("unexpected transport {other:?}"),
    }
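For reference, a minimal sketch (not part of the diff) of the config.toml shape these assertions pin down; the table names come from the `contains` checks above, while the ordering and exact layout are assumptions:

fn main() {
    let expected_shape = r#"
[mcp_servers.docs]
url = "https://example.com/mcp"
bearer_token_env_var = "MCP_TOKEN"

[mcp_servers.docs.http_headers]
"X-Doc" = "42"

[mcp_servers.docs.env_http_headers]
"X-Auth" = "DOCS_AUTH"

[mcp_servers.logs]
command = "logs-server"
args = ["--follow"]
"#;
    // No [mcp_servers.logs.http_headers] or [mcp_servers.logs.env_http_headers]
    // tables appear: the test asserts header tables never leak onto the stdio server.
    println!("{expected_shape}");
}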
@@ -1880,6 +2397,8 @@ url = "https://example.com/mcp"
                command: "docs-server".to_string(),
                args: Vec::new(),
                env: None,
                env_vars: Vec::new(),
                cwd: None,
            },
            enabled: false,
            startup_timeout_sec: None,
@@ -2195,6 +2714,7 @@ model_verbosity = "high"
        model_provider: fixture.openai_provider.clone(),
        approval_policy: AskForApproval::Never,
        sandbox_policy: SandboxPolicy::new_read_only_policy(),
        did_user_set_custom_approval_policy_or_sandbox_mode: true,
        shell_environment_policy: ShellEnvironmentPolicy::default(),
        user_instructions: None,
        notify: None,
@@ -2224,7 +2744,9 @@ model_verbosity = "high"
        include_view_image_tool: true,
        features: Features::with_defaults(),
        active_profile: Some("o3".to_string()),
        active_project: ProjectConfig { trust_level: None },
        windows_wsl_setup_acknowledged: false,
        notices: Default::default(),
        disable_paste_burst: false,
        tui_notifications: Default::default(),
        otel: OtelConfig::default(),
@@ -2259,6 +2781,7 @@ model_verbosity = "high"
        model_provider: fixture.openai_chat_completions_provider.clone(),
        approval_policy: AskForApproval::UnlessTrusted,
        sandbox_policy: SandboxPolicy::new_read_only_policy(),
        did_user_set_custom_approval_policy_or_sandbox_mode: true,
        shell_environment_policy: ShellEnvironmentPolicy::default(),
        user_instructions: None,
        notify: None,
@@ -2288,7 +2811,9 @@ model_verbosity = "high"
        include_view_image_tool: true,
        features: Features::with_defaults(),
        active_profile: Some("gpt3".to_string()),
        active_project: ProjectConfig { trust_level: None },
        windows_wsl_setup_acknowledged: false,
        notices: Default::default(),
        disable_paste_burst: false,
        tui_notifications: Default::default(),
        otel: OtelConfig::default(),
@@ -2338,6 +2863,7 @@ model_verbosity = "high"
        model_provider: fixture.openai_provider.clone(),
        approval_policy: AskForApproval::OnFailure,
        sandbox_policy: SandboxPolicy::new_read_only_policy(),
        did_user_set_custom_approval_policy_or_sandbox_mode: true,
        shell_environment_policy: ShellEnvironmentPolicy::default(),
        user_instructions: None,
        notify: None,
@@ -2367,7 +2893,9 @@ model_verbosity = "high"
        include_view_image_tool: true,
        features: Features::with_defaults(),
        active_profile: Some("zdr".to_string()),
        active_project: ProjectConfig { trust_level: None },
        windows_wsl_setup_acknowledged: false,
        notices: Default::default(),
        disable_paste_burst: false,
        tui_notifications: Default::default(),
        otel: OtelConfig::default(),
@@ -2403,6 +2931,7 @@ model_verbosity = "high"
        model_provider: fixture.openai_provider.clone(),
        approval_policy: AskForApproval::OnFailure,
        sandbox_policy: SandboxPolicy::new_read_only_policy(),
        did_user_set_custom_approval_policy_or_sandbox_mode: true,
        shell_environment_policy: ShellEnvironmentPolicy::default(),
        user_instructions: None,
        notify: None,
@@ -2432,7 +2961,9 @@ model_verbosity = "high"
        include_view_image_tool: true,
        features: Features::with_defaults(),
        active_profile: Some("gpt5".to_string()),
        active_project: ProjectConfig { trust_level: None },
        windows_wsl_setup_acknowledged: false,
        notices: Default::default(),
        disable_paste_burst: false,
        tui_notifications: Default::default(),
        otel: OtelConfig::default(),
@@ -2443,6 +2974,24 @@ model_verbosity = "high"
    Ok(())
}

#[test]
fn test_did_user_set_custom_approval_policy_or_sandbox_mode_defaults_no() -> anyhow::Result<()>
{
    let fixture = create_test_fixture()?;

    let config = Config::load_from_base_config_with_overrides(
        fixture.cfg.clone(),
        ConfigOverrides {
            ..Default::default()
        },
        fixture.codex_home(),
    )?;

    assert!(config.did_user_set_custom_approval_policy_or_sandbox_mode);

    Ok(())
}

#[test]
fn test_set_project_trusted_writes_explicit_tables() -> anyhow::Result<()> {
    let project_dir = Path::new("/some/path");

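A hedged variant of the test above (not in the diff): passing an explicit approval policy through the overrides should also set the flag, mirroring the profile expectations earlier in this file. The `approval_policy` field name on ConfigOverrides is an assumption, since the diff only shows the `..Default::default()` form:

#[test]
fn custom_approval_policy_override_sets_flag() -> anyhow::Result<()> {
    let fixture = create_test_fixture()?;

    // `approval_policy: Some(...)` is hypothetical; only the default-override
    // form appears in this diff.
    let config = Config::load_from_base_config_with_overrides(
        fixture.cfg.clone(),
        ConfigOverrides {
            approval_policy: Some(AskForApproval::Never),
            ..Default::default()
        },
        fixture.codex_home(),
    )?;

    assert!(config.did_user_set_custom_approval_policy_or_sandbox_mode);

    Ok(())
}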
@@ -44,16 +44,26 @@ impl<'de> Deserialize<'de> for McpServerConfig {
    {
        #[derive(Deserialize)]
        struct RawMcpServerConfig {
            // stdio
            command: Option<String>,
            #[serde(default)]
            args: Option<Vec<String>>,
            #[serde(default)]
            env: Option<HashMap<String, String>>,
            #[serde(default)]
            env_vars: Option<Vec<String>>,
            #[serde(default)]
            cwd: Option<PathBuf>,
            http_headers: Option<HashMap<String, String>>,
            #[serde(default)]
            env_http_headers: Option<HashMap<String, String>>,

            // streamable_http
            url: Option<String>,
            bearer_token: Option<String>,
            bearer_token_env_var: Option<String>,

            // shared
            #[serde(default)]
            startup_timeout_sec: Option<f64>,
            #[serde(default)]
@@ -92,8 +102,12 @@ impl<'de> Deserialize<'de> for McpServerConfig {
                command: Some(command),
                args,
                env,
                env_vars,
                cwd,
                url,
                bearer_token_env_var,
                http_headers,
                env_http_headers,
                ..
            } => {
                throw_if_set("stdio", "url", url.as_ref())?;
@@ -102,10 +116,14 @@ impl<'de> Deserialize<'de> for McpServerConfig {
                    "bearer_token_env_var",
                    bearer_token_env_var.as_ref(),
                )?;
                throw_if_set("stdio", "http_headers", http_headers.as_ref())?;
                throw_if_set("stdio", "env_http_headers", env_http_headers.as_ref())?;
                McpServerTransportConfig::Stdio {
                    command,
                    args: args.unwrap_or_default(),
                    env,
                    env_vars: env_vars.unwrap_or_default(),
                    cwd,
                }
            }
            RawMcpServerConfig {
@@ -115,15 +133,26 @@ impl<'de> Deserialize<'de> for McpServerConfig {
                command,
                args,
                env,
                ..
                env_vars,
                cwd,
                http_headers,
                env_http_headers,
                startup_timeout_sec: _,
                tool_timeout_sec: _,
                startup_timeout_ms: _,
                enabled: _,
            } => {
                throw_if_set("streamable_http", "command", command.as_ref())?;
                throw_if_set("streamable_http", "args", args.as_ref())?;
                throw_if_set("streamable_http", "env", env.as_ref())?;
                throw_if_set("streamable_http", "env_vars", env_vars.as_ref())?;
                throw_if_set("streamable_http", "cwd", cwd.as_ref())?;
                throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?;
                McpServerTransportConfig::StreamableHttp {
                    url,
                    bearer_token_env_var,
                    http_headers,
                    env_http_headers,
                }
            }
            _ => return Err(SerdeError::custom("invalid transport")),
@@ -152,6 +181,10 @@ pub enum McpServerTransportConfig {
        args: Vec<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        env: Option<HashMap<String, String>>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        env_vars: Vec<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cwd: Option<PathBuf>,
    },
    /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
    StreamableHttp {
@@ -161,6 +194,12 @@ pub enum McpServerTransportConfig {
        /// The actual secret value must be provided via the environment.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        bearer_token_env_var: Option<String>,
        /// Additional HTTP headers to include in requests to this server.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        http_headers: Option<HashMap<String, String>>,
        /// HTTP headers where the value is sourced from an environment variable.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        env_http_headers: Option<HashMap<String, String>>,
    },
}
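The enum accepts two mutually exclusive TOML shapes: the presence of `command` selects Stdio and the presence of `url` selects StreamableHttp, with mixed fields rejected by the deserializer above. A sketch (not in the diff) using only keys exercised by the tests below:

fn main() {
    // Stdio transport: command plus optional args/env/env_vars/cwd.
    let stdio = r#"
command = "echo"
args = ["hello", "world"]
env = { FOO = "BAR" }
env_vars = ["FOO", "BAR"]
cwd = "/tmp"
"#;
    // Streamable HTTP transport: url plus optional auth and header maps.
    let streamable_http = r#"
url = "https://example.com/mcp"
bearer_token_env_var = "GITHUB_TOKEN"
http_headers = { "X-Foo" = "bar" }
env_http_headers = { "X-Token" = "TOKEN_ENV" }
"#;
    println!("{stdio}{streamable_http}");
}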
@@ -322,6 +361,20 @@ pub struct Tui {
    pub notifications: Notifications,
}

/// Settings for notices we display to users via the tui and app-server clients
/// (primarily the Codex IDE extension). NOTE: these are different from
/// notifications - notices are warnings, NUX screens, acknowledgements, etc.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Notice {
    /// Tracks whether the user has acknowledged the full access warning prompt.
    pub hide_full_access_warning: Option<bool>,
}

impl Notice {
    /// used by set_hide_full_access_warning until we refactor config updates
    pub(crate) const TABLE_KEY: &'static str = "notice";
}

#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct SandboxWorkspaceWrite {
    #[serde(default)]
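`Notice::TABLE_KEY` ties the struct to a `[notice]` table; a minimal fragment sketch (not in the diff), assuming the table sits at the top level of config.toml:

fn main() {
    // Field name matches the Notice struct above; placement is assumed.
    let fragment = r#"
[notice]
hide_full_access_warning = true
"#;
    println!("{fragment}");
}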
@@ -468,7 +521,9 @@ mod tests {
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
                env: None
                env: None,
                env_vars: Vec::new(),
                cwd: None,
            }
        );
        assert!(cfg.enabled);
@@ -489,7 +544,9 @@ mod tests {
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec!["hello".to_string(), "world".to_string()],
                env: None
                env: None,
                env_vars: Vec::new(),
                cwd: None,
            }
        );
        assert!(cfg.enabled);
@@ -511,12 +568,58 @@ mod tests {
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec!["hello".to_string(), "world".to_string()],
                env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())]))
                env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])),
                env_vars: Vec::new(),
                cwd: None,
            }
        );
        assert!(cfg.enabled);
    }

    #[test]
    fn deserialize_stdio_command_server_config_with_env_vars() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
command = "echo"
env_vars = ["FOO", "BAR"]
"#,
        )
        .expect("should deserialize command config with env_vars");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
                env: None,
                env_vars: vec!["FOO".to_string(), "BAR".to_string()],
                cwd: None,
            }
        );
    }

    #[test]
    fn deserialize_stdio_command_server_config_with_cwd() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
command = "echo"
cwd = "/tmp"
"#,
        )
        .expect("should deserialize command config with cwd");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
                env: None,
                env_vars: Vec::new(),
                cwd: Some(PathBuf::from("/tmp")),
            }
        );
    }

    #[test]
    fn deserialize_disabled_server_config() {
        let cfg: McpServerConfig = toml::from_str(
@@ -543,7 +646,9 @@ mod tests {
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: None
                bearer_token_env_var: None,
                http_headers: None,
                env_http_headers: None,
            }
        );
        assert!(cfg.enabled);
@@ -563,12 +668,39 @@ mod tests {
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: Some("GITHUB_TOKEN".to_string())
                bearer_token_env_var: Some("GITHUB_TOKEN".to_string()),
                http_headers: None,
                env_http_headers: None,
            }
        );
        assert!(cfg.enabled);
    }

    #[test]
    fn deserialize_streamable_http_server_config_with_headers() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
url = "https://example.com/mcp"
http_headers = { "X-Foo" = "bar" }
env_http_headers = { "X-Token" = "TOKEN_ENV" }
"#,
        )
        .expect("should deserialize http config with headers");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: None,
                http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])),
                env_http_headers: Some(HashMap::from([(
                    "X-Token".to_string(),
                    "TOKEN_ENV".to_string()
                )])),
            }
        );
    }

    #[test]
    fn deserialize_rejects_command_and_url() {
        toml::from_str::<McpServerConfig>(
@@ -591,6 +723,25 @@ mod tests {
        .expect_err("should reject env for http transport");
    }

    #[test]
    fn deserialize_rejects_headers_for_stdio() {
        toml::from_str::<McpServerConfig>(
            r#"
command = "echo"
http_headers = { "X-Foo" = "bar" }
"#,
        )
        .expect_err("should reject http_headers for stdio transport");

        toml::from_str::<McpServerConfig>(
            r#"
command = "echo"
env_http_headers = { "X-Foo" = "BAR_ENV" }
"#,
        )
        .expect_err("should reject env_http_headers for stdio transport");
    }

    #[test]
    fn deserialize_rejects_inline_bearer_token_field() {
        let err = toml::from_str::<McpServerConfig>(

@@ -7,10 +7,16 @@ use crate::codex::compact::content_items_to_text;
use crate::codex::compact::is_session_prefix_message;
use crate::codex_conversation::CodexConversation;
use crate::config::Config;
use crate::cross_session::CrossSessionError;
use crate::cross_session::CrossSessionHub;
use crate::cross_session::RegisteredSession;
use crate::cross_session::SessionDefaults;
use crate::cross_session::SessionRegistration;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::Op;
use crate::protocol::SessionConfiguredEvent;
use crate::rollout::RolloutRecorder;
use codex_protocol::ConversationId;
@@ -22,6 +28,7 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::warn;

/// Represents a newly created Codex conversation, including the first event
/// (which is [`EventMsg::SessionConfigured`]).
@@ -31,10 +38,17 @@ pub struct NewConversation {
    pub session_configured: SessionConfiguredEvent,
}

pub struct CrossSessionSpawnParams {
    pub hub: Arc<CrossSessionHub>,
    pub run_id: Option<String>,
    pub role: Option<String>,
}

/// [`ConversationManager`] is responsible for creating conversations and
/// maintaining them in memory.
pub struct ConversationManager {
    conversations: Arc<RwLock<HashMap<ConversationId, Arc<CodexConversation>>>>,
    cross_session_registrations: Arc<RwLock<HashMap<ConversationId, RegisteredSession>>>,
    auth_manager: Arc<AuthManager>,
    session_source: SessionSource,
}
@@ -43,6 +57,7 @@ impl ConversationManager {
    pub fn new(auth_manager: Arc<AuthManager>, session_source: SessionSource) -> Self {
        Self {
            conversations: Arc::new(RwLock::new(HashMap::new())),
            cross_session_registrations: Arc::new(RwLock::new(HashMap::new())),
            auth_manager,
            session_source,
        }
@@ -58,26 +73,104 @@ impl ConversationManager {
    }

    pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
        self.spawn_conversation(config, self.auth_manager.clone())
            .await
        self.spawn_conversation_with_history(
            config,
            self.auth_manager.clone(),
            InitialHistory::New,
            None,
        )
        .await
    }

    async fn spawn_conversation(
    pub async fn new_conversation_with_cross_session(
        &self,
        config: Config,
        params: CrossSessionSpawnParams,
    ) -> CodexResult<NewConversation> {
        self.spawn_conversation_with_history(
            config,
            self.auth_manager.clone(),
            InitialHistory::New,
            Some(params),
        )
        .await
    }

    async fn spawn_conversation_with_history(
        &self,
        config: Config,
        auth_manager: Arc<AuthManager>,
        initial_history: InitialHistory,
        cross_session: Option<CrossSessionSpawnParams>,
    ) -> CodexResult<NewConversation> {
        let cross_session =
            cross_session.map(|params| (SessionDefaults::from_config(&config), params));

        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(
            config,
            auth_manager,
            InitialHistory::New,
            self.session_source,
        )
        .await?;
        self.finalize_spawn(codex, conversation_id).await
        } = Codex::spawn(config, auth_manager, initial_history, self.session_source).await?;

        let new_conversation = self.finalize_spawn(codex, conversation_id).await?;

        if let Some((defaults, params)) = cross_session
            && let Err(err) = self
                .register_cross_session(
                    conversation_id,
                    defaults,
                    params,
                    Arc::clone(&new_conversation.conversation),
                )
                .await
        {
            self.abort_conversation(conversation_id, Arc::clone(&new_conversation.conversation))
                .await;
            return Err(CodexErr::Fatal(format!(
                "failed to register cross-session for conversation {conversation_id}: {err}"
            )));
        }

        Ok(new_conversation)
    }

    async fn register_cross_session(
        &self,
        conversation_id: ConversationId,
        defaults: SessionDefaults,
        params: CrossSessionSpawnParams,
        conversation: Arc<CodexConversation>,
    ) -> Result<(), CrossSessionError> {
        let CrossSessionSpawnParams { hub, run_id, role } = params;

        let registration = SessionRegistration {
            conversation_id,
            conversation,
            defaults,
            run_id,
            role,
        };

        let guard = hub.register_session(registration)?;
        self.cross_session_registrations
            .write()
            .await
            .insert(conversation_id, guard);
        Ok(())
    }

    async fn abort_conversation(
        &self,
        conversation_id: ConversationId,
        conversation: Arc<CodexConversation>,
    ) {
        let _ = self.remove_conversation(&conversation_id).await;
        if let Err(err) = conversation.submit(Op::Shutdown).await {
            warn!(
                %conversation_id,
                ?err,
                "failed to shutdown conversation after cross-session registration error"
            );
        }
    }

    async fn finalize_spawn(
@@ -130,11 +223,35 @@ impl ConversationManager {
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewConversation> {
        let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(config, auth_manager, initial_history, self.session_source).await?;
        self.finalize_spawn(codex, conversation_id).await
        self.spawn_conversation_with_history(config, auth_manager, initial_history, None)
            .await
    }

    pub async fn resume_conversation_from_rollout_with_cross_session(
        &self,
        config: Config,
        rollout_path: PathBuf,
        auth_manager: Arc<AuthManager>,
        params: CrossSessionSpawnParams,
    ) -> CodexResult<NewConversation> {
        let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
        self.spawn_conversation_with_history(config, auth_manager, initial_history, Some(params))
            .await
    }

    pub async fn resume_conversation_with_cross_session(
        &self,
        config: Config,
        rollout_path: PathBuf,
        params: CrossSessionSpawnParams,
    ) -> CodexResult<NewConversation> {
        self.resume_conversation_from_rollout_with_cross_session(
            config,
            rollout_path,
            self.auth_manager.clone(),
            params,
        )
        .await
    }

    /// Removes the conversation from the manager's internal map, though the
@@ -145,6 +262,10 @@ impl ConversationManager {
        &self,
        conversation_id: &ConversationId,
    ) -> Option<Arc<CodexConversation>> {
        self.cross_session_registrations
            .write()
            .await
            .remove(conversation_id);
        self.conversations.write().await.remove(conversation_id)
    }

@@ -164,12 +285,23 @@ impl ConversationManager {

        // Spawn a new conversation with the computed initial history.
        let auth_manager = self.auth_manager.clone();
        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(config, auth_manager, history, self.session_source).await?;
        self.spawn_conversation_with_history(config, auth_manager, history, None)
            .await
    }

        self.finalize_spawn(codex, conversation_id).await
    pub async fn fork_conversation_with_cross_session(
        &self,
        nth_user_message: usize,
        config: Config,
        path: PathBuf,
        params: CrossSessionSpawnParams,
    ) -> CodexResult<NewConversation> {
        let history = RolloutRecorder::get_rollout_history(&path).await?;
        let history = truncate_before_nth_user_message(history, nth_user_message);

        let auth_manager = self.auth_manager.clone();
        self.spawn_conversation_with_history(config, auth_manager, history, Some(params))
            .await
    }
}

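A minimal usage sketch (not part of the diff) of the new cross-session spawn path; the run_id/role literals are illustrative and the Config/hub construction is elided:

use std::sync::Arc;

async fn spawn_worker(
    manager: &ConversationManager,
    hub: Arc<CrossSessionHub>,
    config: Config,
) -> CodexResult<NewConversation> {
    // Registers the conversation under ("run-1", "worker") so peers can target
    // it via RoleOrId::RunRole; on registration failure the manager shuts the
    // conversation down and surfaces CodexErr::Fatal, per the code above.
    manager
        .new_conversation_with_cross_session(
            config,
            CrossSessionSpawnParams {
                hub,
                run_id: Some("run-1".to_string()),
                role: Some("worker".to_string()),
            },
        )
        .await
}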
codex-rs/core/src/cross_session.rs (new file, 607 lines)
@@ -0,0 +1,607 @@
use std::collections::HashMap;
use std::fmt;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::RwLock;
use std::sync::RwLockReadGuard;
use std::sync::RwLockWriteGuard;
use std::time::Duration;

use futures::Stream;
use serde_json::Value;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::time;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tracing::debug;
use tracing::error;

use crate::codex_conversation::CodexConversation;
use crate::config::Config;
use crate::error::CodexErr;
use crate::protocol::AgentMessageEvent;
use crate::protocol::AskForApproval;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::InputItem;
use crate::protocol::Op;
use crate::protocol::SandboxPolicy;
use crate::protocol_config_types::ReasoningEffort as ReasoningEffortConfig;
use crate::protocol_config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::ConversationId;

/// Default capacity for broadcast channels that fan out session events.
const EVENT_BUFFER_LEN: usize = 256;

/// Encapsulates the defaults needed to submit a new `Op::UserTurn`.
#[derive(Debug, Clone)]
pub struct SessionDefaults {
    pub cwd: PathBuf,
    pub approval_policy: AskForApproval,
    pub sandbox_policy: SandboxPolicy,
    pub model: String,
    pub effort: Option<ReasoningEffortConfig>,
    pub summary: ReasoningSummaryConfig,
}

impl SessionDefaults {
    pub fn from_config(config: &Config) -> Self {
        Self {
            cwd: config.cwd.clone(),
            approval_policy: config.approval_policy,
            sandbox_policy: config.sandbox_policy.clone(),
            model: config.model.clone(),
            effort: config.model_reasoning_effort,
            summary: config.model_reasoning_summary,
        }
    }
}

/// Request payload for posting a user turn to a session.
#[derive(Debug, Clone)]
pub struct PostUserTurnRequest {
    pub target: RoleOrId,
    pub text: String,
    pub final_output_json_schema: Option<Value>,
}

/// Identifier used when targeting sessions for cross-session routing.
#[derive(Debug, Clone)]
pub enum RoleOrId {
    Session(ConversationId),
    RunRole { run_id: String, role: String },
}

/// Handle returned by [`CrossSessionHub::post_user_turn`].
pub struct TurnHandle {
    conversation_id: ConversationId,
    submission_id: String,
    receiver: TokioMutex<Option<oneshot::Receiver<AssistantMessage>>>,
}

impl TurnHandle {
    pub fn conversation_id(&self) -> ConversationId {
        self.conversation_id
    }

    pub fn submission_id(&self) -> &str {
        &self.submission_id
    }
}

impl fmt::Debug for TurnHandle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TurnHandle")
            .field("conversation_id", &self.conversation_id)
            .field("submission_id", &self.submission_id)
            .finish()
    }
}

/// First assistant message emitted for a bridged turn.
#[derive(Debug, Clone)]
pub struct AssistantMessage {
    pub conversation_id: ConversationId,
    pub submission_id: String,
    pub message: AgentMessageEvent,
}

/// Wrapper around a session event tagged with its conversation id.
#[derive(Debug, Clone)]
pub struct SessionEvent {
    pub conversation_id: ConversationId,
    pub event: Event,
}

/// Stream of [`SessionEvent`] instances for a particular session.
pub struct SessionEventStream {
    inner: BroadcastStream<SessionEvent>,
}

impl SessionEventStream {
    fn new(receiver: broadcast::Receiver<SessionEvent>) -> Self {
        Self {
            inner: BroadcastStream::new(receiver),
        }
    }
}

impl Stream for SessionEventStream {
    type Item = SessionEvent;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                std::task::Poll::Ready(Some(Ok(event))) => {
                    return std::task::Poll::Ready(Some(event));
                }
                std::task::Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => continue,
                std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
                std::task::Poll::Pending => return std::task::Poll::Pending,
            }
        }
    }
}

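A consumption sketch (not part of the diff) using `tokio_stream::StreamExt`; `hub.stream_events` is defined later in this file. Lagged receivers are silently skipped inside `poll_next`, so a slow consumer observes a gap rather than an error:

use tokio_stream::StreamExt;

async fn tail_session(
    hub: &CrossSessionHub,
    id: ConversationId,
) -> Result<(), CrossSessionError> {
    let mut events = hub.stream_events(id)?;
    // The stream terminates when the session's broadcast sender is dropped,
    // i.e. after the forwarder task unregisters the session.
    while let Some(SessionEvent { event, .. }) = events.next().await {
        tracing::debug!("session event: {:?}", event.msg);
    }
    Ok(())
}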
#[derive(Clone)]
struct RoleKey {
    run_id: Arc<str>,
    role: Arc<str>,
}

impl RoleKey {
    fn new(run_id: String, role: String) -> Self {
        Self {
            run_id: Arc::<str>::from(run_id),
            role: Arc::<str>::from(role),
        }
    }
}

impl PartialEq for RoleKey {
    fn eq(&self, other: &Self) -> bool {
        self.run_id.as_ref() == other.run_id.as_ref() && self.role.as_ref() == other.role.as_ref()
    }
}

impl Eq for RoleKey {}

impl std::hash::Hash for RoleKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(self.run_id.as_ref(), state);
        std::hash::Hash::hash(self.role.as_ref(), state);
    }
}

struct SessionEntry {
    conversation_id: ConversationId,
    conversation: Arc<CodexConversation>,
    defaults: SessionDefaults,
    role_key: Option<RoleKey>,
    event_tx: broadcast::Sender<SessionEvent>,
    turn_watchers: TokioMutex<HashMap<String, oneshot::Sender<AssistantMessage>>>,
    pending_messages: TokioMutex<HashMap<String, AssistantMessage>>,
    shutdown_tx: StdMutex<Option<oneshot::Sender<()>>>,
}

impl SessionEntry {
    fn new(
        conversation_id: ConversationId,
        conversation: Arc<CodexConversation>,
        defaults: SessionDefaults,
        role_key: Option<RoleKey>,
        event_tx: broadcast::Sender<SessionEvent>,
        shutdown_tx: oneshot::Sender<()>,
    ) -> Self {
        Self {
            conversation_id,
            conversation,
            defaults,
            role_key,
            event_tx,
            turn_watchers: TokioMutex::new(HashMap::new()),
            pending_messages: TokioMutex::new(HashMap::new()),
            shutdown_tx: StdMutex::new(Some(shutdown_tx)),
        }
    }

    async fn register_waiter(
        &self,
        submission_id: String,
        sender: oneshot::Sender<AssistantMessage>,
    ) {
        {
            let mut watchers = self.turn_watchers.lock().await;
            if let Some(message) = {
                let mut pending = self.pending_messages.lock().await;
                pending.remove(&submission_id)
            } {
                drop(watchers);
                let _ = sender.send(message);
                return;
            }
            watchers.insert(submission_id, sender);
        }
    }

    async fn notify_assistant_message(&self, message: AssistantMessage) {
        let submission_id = message.submission_id.clone();
        let sender_opt = {
            let mut watchers = self.turn_watchers.lock().await;
            watchers.remove(&submission_id)
        };

        if let Some(sender) = sender_opt {
            let _ = sender.send(message);
        } else {
            let mut pending = self.pending_messages.lock().await;
            pending.entry(submission_id).or_insert(message);
        }
    }

    fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
        self.event_tx.subscribe()
    }

    fn close(&self) {
        if let Ok(mut guard) = self.shutdown_tx.lock()
            && let Some(tx) = guard.take()
        {
            let _ = tx.send(());
        }
    }

    fn role_key(&self) -> Option<RoleKey> {
        self.role_key.clone()
    }
}

/// Input for registering a session with the hub.
pub struct SessionRegistration {
    pub conversation_id: ConversationId,
    pub conversation: Arc<CodexConversation>,
    pub defaults: SessionDefaults,
    pub run_id: Option<String>,
    pub role: Option<String>,
}

/// Guard that unregisters the session on drop.
pub struct RegisteredSession {
    inner: Arc<Inner>,
    conversation_id: ConversationId,
}

impl RegisteredSession {
    pub fn conversation_id(&self) -> ConversationId {
        self.conversation_id
    }
}

impl Drop for RegisteredSession {
    fn drop(&mut self) {
        self.inner.unregister(self.conversation_id);
    }
}

#[derive(Default)]
struct Inner {
    sessions: RwLock<HashMap<ConversationId, Arc<SessionEntry>>>,
    roles: RwLock<HashMap<RoleKey, ConversationId>>,
}

impl Inner {
    fn sessions_read(
        &self,
    ) -> Result<RwLockReadGuard<'_, HashMap<ConversationId, Arc<SessionEntry>>>, CrossSessionError>
    {
        self.sessions
            .read()
            .map_err(|_| CrossSessionError::LockPoisoned("sessions"))
    }

    fn sessions_write(
        &self,
    ) -> Result<RwLockWriteGuard<'_, HashMap<ConversationId, Arc<SessionEntry>>>, CrossSessionError>
    {
        self.sessions
            .write()
            .map_err(|_| CrossSessionError::LockPoisoned("sessions"))
    }

    fn roles_read(
        &self,
    ) -> Result<RwLockReadGuard<'_, HashMap<RoleKey, ConversationId>>, CrossSessionError> {
        self.roles
            .read()
            .map_err(|_| CrossSessionError::LockPoisoned("roles"))
    }

    fn roles_write(
        &self,
    ) -> Result<RwLockWriteGuard<'_, HashMap<RoleKey, ConversationId>>, CrossSessionError> {
        self.roles
            .write()
            .map_err(|_| CrossSessionError::LockPoisoned("roles"))
    }

    fn insert(&self, entry: Arc<SessionEntry>) -> Result<(), CrossSessionError> {
        {
            let mut sessions = self.sessions_write()?;
            if sessions
                .insert(entry.conversation_id, entry.clone())
                .is_some()
            {
                return Err(CrossSessionError::SessionAlreadyRegistered(
                    entry.conversation_id,
                ));
            }
        }

        if let Some(role_key) = entry.role_key() {
            let mut roles = self.roles_write()?;
            if roles.contains_key(&role_key) {
                self.sessions_write()?.remove(&entry.conversation_id);
                return Err(CrossSessionError::RoleAlreadyRegistered {
                    run_id: role_key.run_id.to_string(),
                    role: role_key.role.to_string(),
                });
            }
            roles.insert(role_key, entry.conversation_id);
        }

        Ok(())
    }

    fn unregister(&self, conversation_id: ConversationId) {
        if let Some(entry) = self.remove_internal(conversation_id) {
            entry.close();
        }
    }

    fn remove_internal(&self, conversation_id: ConversationId) -> Option<Arc<SessionEntry>> {
        let (entry, role_key) = {
            let mut sessions = self.sessions.write().ok()?;
            let entry = sessions.remove(&conversation_id)?;
            let role_key = entry.role_key();
            (entry, role_key)
        };

        if let Some(role_key) = role_key
            && let Ok(mut roles) = self.roles.write()
        {
            roles.remove(&role_key);
        }

        Some(entry)
    }

    fn resolve_session(
        &self,
        conversation_id: ConversationId,
    ) -> Result<Arc<SessionEntry>, CrossSessionError> {
        self.sessions_read()?
            .get(&conversation_id)
            .cloned()
            .ok_or(CrossSessionError::SessionNotFound(conversation_id))
    }

    fn resolve_target(&self, target: &RoleOrId) -> Result<Arc<SessionEntry>, CrossSessionError> {
        match target {
            RoleOrId::Session(id) => self.resolve_session(*id),
            RoleOrId::RunRole { run_id, role } => {
                let conversation_id = {
                    let roles = self.roles_read()?;
                    let key = RoleKey::new(run_id.clone(), role.clone());
                    roles
                        .get(&key)
                        .copied()
                        .ok_or_else(|| CrossSessionError::RoleNotFound {
                            run_id: run_id.clone(),
                            role: role.clone(),
                        })?
                };
                self.resolve_session(conversation_id)
            }
        }
    }
}

/// Cross-session coordination hub.
#[derive(Default, Clone)]
pub struct CrossSessionHub {
    inner: Arc<Inner>,
}

impl CrossSessionHub {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn register_session(
        &self,
        registration: SessionRegistration,
    ) -> Result<RegisteredSession, CrossSessionError> {
        let SessionRegistration {
            conversation_id,
            conversation,
            defaults,
            run_id,
            role,
        } = registration;

        let role_key = match (run_id, role) {
            (Some(run_id), Some(role)) => Some(RoleKey::new(run_id, role)),
            (None, None) => None,
            _ => {
                return Err(CrossSessionError::IncompleteRoleRegistration);
            }
        };

        let (event_tx, _) = broadcast::channel(EVENT_BUFFER_LEN);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let entry = Arc::new(SessionEntry::new(
            conversation_id,
            Arc::clone(&conversation),
            defaults,
            role_key,
            event_tx,
            shutdown_tx,
        ));

        self.inner.insert(entry.clone())?;

        self.spawn_event_forwarder(entry, conversation, shutdown_rx);

        Ok(RegisteredSession {
            inner: Arc::clone(&self.inner),
            conversation_id,
        })
    }

    pub async fn post_user_turn(
        &self,
        request: PostUserTurnRequest,
    ) -> Result<TurnHandle, CrossSessionError> {
        let entry = self.inner.resolve_target(&request.target)?;

        let items = vec![InputItem::Text { text: request.text }];

        let defaults = &entry.defaults;
        let submission_id = entry
            .conversation
            .submit(Op::UserTurn {
                items,
                cwd: defaults.cwd.clone(),
                approval_policy: defaults.approval_policy,
                sandbox_policy: defaults.sandbox_policy.clone(),
                model: defaults.model.clone(),
                effort: defaults.effort,
                summary: defaults.summary,
                final_output_json_schema: request.final_output_json_schema,
            })
            .await
            .map_err(CrossSessionError::from)?;

        let (tx, rx) = oneshot::channel();

        entry.register_waiter(submission_id.clone(), tx).await;

        Ok(TurnHandle {
            conversation_id: entry.conversation_id,
            submission_id,
            receiver: TokioMutex::new(Some(rx)),
        })
    }

    pub async fn await_first_assistant(
        &self,
        handle: &TurnHandle,
        timeout: Duration,
    ) -> Result<AssistantMessage, CrossSessionError> {
        let receiver = {
            let mut guard = handle.receiver.lock().await;
            guard.take().ok_or(CrossSessionError::TurnHandleConsumed)?
        };

        match time::timeout(timeout, receiver).await {
            Ok(Ok(message)) => Ok(message),
            Ok(Err(_)) => Err(CrossSessionError::SessionClosed),
            Err(_) => Err(CrossSessionError::AwaitTimeout(timeout)),
        }
    }

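An end-to-end sketch (not in the diff) of the request/response path: post a turn by role, then block on the first assistant message. The literals and the 30-second timeout are illustrative:

async fn ask_planner(hub: &CrossSessionHub) -> Result<AssistantMessage, CrossSessionError> {
    let handle = hub
        .post_user_turn(PostUserTurnRequest {
            target: RoleOrId::RunRole {
                run_id: "run-1".to_string(),
                role: "planner".to_string(),
            },
            text: "Summarize the current plan.".to_string(),
            final_output_json_schema: None,
        })
        .await?;

    // Consumes the handle's one-shot receiver; a second await on the same
    // handle yields CrossSessionError::TurnHandleConsumed.
    hub.await_first_assistant(&handle, Duration::from_secs(30)).await
}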
    pub fn stream_events(
        &self,
        conversation_id: ConversationId,
    ) -> Result<SessionEventStream, CrossSessionError> {
        let entry = self.inner.resolve_session(conversation_id)?;
        Ok(SessionEventStream::new(entry.subscribe()))
    }

    fn spawn_event_forwarder(
        &self,
        entry: Arc<SessionEntry>,
        conversation: Arc<CodexConversation>,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) {
        let conversation_id = entry.conversation_id;
        let event_tx = entry.event_tx.clone();
        let inner = Arc::clone(&self.inner);

        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => {
                        debug!("CrossSessionHub received shutdown for session {conversation_id}");
                        break;
                    }
                    event = conversation.next_event() => {
                        match event {
                            Ok(event) => {
                                if let EventMsg::AgentMessage(agent_message) = &event.msg {
                                    let message = AssistantMessage {
                                        conversation_id,
                                        submission_id: event.id.clone(),
                                        message: agent_message.clone(),
                                    };
                                    entry.notify_assistant_message(message).await;
                                }

                                if let Err(err) = event_tx.send(SessionEvent {
                                    conversation_id,
                                    event: event.clone(),
                                }) {
                                    debug!(
                                        "CrossSessionHub dropped event for session {conversation_id}: {err}"
                                    );
                                }

                                if matches!(event.msg, EventMsg::ShutdownComplete) {
                                    break;
                                }
                            }
                            Err(err) => {
                                error!("CrossSessionHub event loop error for session {conversation_id}: {err:#?}");
                                break;
                            }
                        }
                    }
                }
            }

            inner.unregister(conversation_id);
        });
    }
}

/// Errors surfaced by cross-session orchestration.
#[derive(thiserror::Error, Debug)]
pub enum CrossSessionError {
    #[error("session {0} is already registered with the hub")]
    SessionAlreadyRegistered(ConversationId),
    #[error("run {run_id} already has a {role} session registered")]
    RoleAlreadyRegistered { run_id: String, role: String },
    #[error("session {0} does not exist")]
    SessionNotFound(ConversationId),
    #[error("no session registered for run {run_id} role {role}")]
    RoleNotFound { run_id: String, role: String },
    #[error("session role registration must set both run_id and role")]
    IncompleteRoleRegistration,
    #[error("turn handle has already been awaited")]
    TurnHandleConsumed,
    #[error("session closed before an assistant message was emitted")]
    SessionClosed,
    #[error("timed out waiting {0:?} for assistant response")]
    AwaitTimeout(Duration),
    #[error("internal lock poisoned: {0}")]
    LockPoisoned(&'static str),
    #[error("submit failed: {0}")]
    SubmitFailed(#[from] CodexErr),
}
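Callers can branch on these variants; a small caller-side policy sketch (not part of the diff):

fn is_retryable(err: &CrossSessionError) -> bool {
    // Timeouts may succeed on retry; routing conflicts, consumed handles, and
    // closed sessions are deterministic and should not be retried.
    matches!(err, CrossSessionError::AwaitTimeout(_))
}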
@@ -91,6 +91,12 @@ pub enum CodexErr {
    #[error("{0}")]
    UsageLimitReached(UsageLimitReachedError),

    #[error("{0}")]
    ResponseStreamFailed(ResponseStreamFailed),

    #[error("{0}")]
    ConnectionFailed(ConnectionFailedError),

    #[error(
        "To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
    )]
@@ -126,9 +132,6 @@ pub enum CodexErr {
    #[error(transparent)]
    Io(#[from] io::Error),

    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),

    #[error(transparent)]
    Json(#[from] serde_json::Error),

@@ -147,6 +150,37 @@ pub enum CodexErr {
    EnvVar(EnvVarError),
}

#[derive(Debug)]
pub struct ConnectionFailedError {
    pub source: reqwest::Error,
}

impl std::fmt::Display for ConnectionFailedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Connection failed: {}", self.source)
    }
}

#[derive(Debug)]
pub struct ResponseStreamFailed {
    pub source: reqwest::Error,
    pub request_id: Option<String>,
}

impl std::fmt::Display for ResponseStreamFailed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Error while reading the server response: {}{}",
            self.source,
            self.request_id
                .as_ref()
                .map(|id| format!(", request id: {id}"))
                .unwrap_or_default()
        )
    }
}

#[derive(Debug)]
pub struct UnexpectedResponseError {
    pub status: StatusCode,

@@ -60,5 +60,9 @@ pub mod errors {
    pub(crate) fn rejection(msg: impl Into<String>) -> Self {
        FunctionCallError::RespondToModel(msg.into()).into()
    }

    pub(crate) fn denied(msg: impl Into<String>) -> Self {
        FunctionCallError::Denied(msg.into()).into()
    }
}
}

@@ -1,3 +1,4 @@
use std::future::Future;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::RwLock;
@@ -74,13 +75,18 @@ impl Executor {
    /// Runs a prepared execution request end-to-end: prepares parameters, decides on
    /// sandbox placement (prompting the user when necessary), launches the command,
    /// and lets the backend post-process the final output.
    pub(crate) async fn run(
    pub(crate) async fn run<F, Fut>(
        &self,
        mut request: ExecutionRequest,
        session: &Session,
        approval_policy: AskForApproval,
        context: &ExecCommandContext,
    ) -> Result<ExecToolCallOutput, ExecError> {
        on_exec_begin: F,
    ) -> Result<ExecToolCallOutput, ExecError>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        if matches!(request.mode, ExecutionMode::Shell) {
            request.params =
                maybe_translate_shell_command(request.params, session, request.use_shell_profile);
@@ -119,7 +125,7 @@ impl Executor {
        if sandbox_decision.record_session_approval {
            self.approval_cache.insert(request.approval_command.clone());
        }

        on_exec_begin().await;
        // Step 4: Launch the command within the chosen sandbox.
        let first_attempt = self
            .spawn(
@@ -210,7 +216,7 @@ impl Executor {
                Ok(retry_output)
            }
            ReviewDecision::Denied | ReviewDecision::Abort => {
                Err(ExecError::rejection("exec command rejected by user"))
                Err(ExecError::denied("exec command rejected by user"))
            }
        }
    }
@@ -301,7 +307,8 @@ pub(crate) fn normalize_exec_result(
        }
        Err(err) => {
            let message = match err {
                ExecError::Function(FunctionCallError::RespondToModel(msg)) => msg.clone(),
                ExecError::Function(FunctionCallError::RespondToModel(msg))
                | ExecError::Function(FunctionCallError::Denied(msg)) => msg.clone(),
                ExecError::Codex(e) => get_error_message_ui(e),
                err => err.to_string(),
            };

@@ -149,7 +149,7 @@ async fn select_shell_sandbox(
        ReviewDecision::Approved => Ok(SandboxDecision::user_override(false)),
        ReviewDecision::ApprovedForSession => Ok(SandboxDecision::user_override(true)),
        ReviewDecision::Denied | ReviewDecision::Abort => {
            Err(ExecError::rejection("exec command rejected by user"))
            Err(ExecError::denied("exec command rejected by user"))
        }
    }
}

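The new generic parameters thread a one-shot async callback into `run`; a hedged in-crate sketch of the call shape (the surrounding request/session/context values are elided assumptions):

async fn run_with_notification(
    executor: &Executor,
    request: ExecutionRequest,
    session: &Session,
    approval_policy: AskForApproval,
    context: &ExecCommandContext,
) -> Result<ExecToolCallOutput, ExecError> {
    // Per the diff above, the closure fires exactly once, after sandbox
    // placement and any session approval are recorded, immediately before
    // the command is spawned.
    executor
        .run(request, session, approval_policy, context, || async {
            // e.g. emit a "command started" notification to the UI here
        })
        .await
}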
@@ -41,6 +41,8 @@ pub enum Feature {
    ViewImageTool,
    /// Allow the model to request web searches.
    WebSearchRequest,
    /// Automatically approve all approval requests from the harness.
    ApproveAll,
}

impl Feature {
@@ -247,4 +249,10 @@ pub const FEATURES: &[FeatureSpec] = &[
        stage: Stage::Stable,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::ApproveAll,
        key: "approve_all",
        stage: Stage::Experimental,
        default_enabled: false,
    },
];

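A sanity-check sketch (not in the diff) against the spec added above; it assumes `FEATURES`, `FeatureSpec`, and `Stage` are visible from the test site:

#[test]
fn approve_all_is_experimental_and_off_by_default() {
    let spec = FEATURES
        .iter()
        .find(|spec| spec.key == "approve_all")
        .expect("approve_all spec should be registered");
    assert!(matches!(spec.stage, Stage::Experimental));
    assert!(!spec.default_enabled);
}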
@@ -4,6 +4,8 @@ use thiserror::Error;
pub enum FunctionCallError {
    #[error("{0}")]
    RespondToModel(String),
    #[error("{0}")]
    Denied(String),
    #[error("LocalShellCall without call_id or id")]
    MissingLocalShellCallId,
    #[error("Fatal error: {0}")]

@@ -13,6 +13,7 @@ mod client;
mod client_common;
pub mod codex;
mod codex_conversation;
pub mod cross_session;
pub mod token_data;
pub use codex_conversation::CodexConversation;
mod command_safety;
@@ -52,6 +53,7 @@ mod event_mapping;
pub mod review_format;
pub use codex_protocol::protocol::InitialHistory;
pub use conversation_manager::ConversationManager;
pub use conversation_manager::CrossSessionSpawnParams;
pub use conversation_manager::NewConversation;
// Re-export common auth types for workspace consumers
pub use auth::AuthManager;

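With `cross_session` public and the re-exports above, downstream crates can reach the hub directly; `codex_core` as the crate name is an assumption based on the codex-rs/core path:

use codex_core::CrossSessionSpawnParams;
use codex_core::cross_session::CrossSessionHub;

fn main() {
    // One hub instance is shared across every spawn that should be routable.
    let params = CrossSessionSpawnParams {
        hub: std::sync::Arc::new(CrossSessionHub::new()),
        run_id: None,
        role: None,
    };
    let _ = params;
}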
@@ -45,11 +45,15 @@ async fn compute_auth_status(
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            determine_streamable_http_auth_status(
                server_name,
                url,
                bearer_token_env_var.as_deref(),
                http_headers.clone(),
                env_http_headers.clone(),
                store_mode,
            )
            .await

@@ -10,6 +10,7 @@ use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -21,6 +22,14 @@ use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::ListResourceTemplatesRequestParams;
|
||||
use mcp_types::ListResourceTemplatesResult;
|
||||
use mcp_types::ListResourcesRequestParams;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ReadResourceRequestParams;
|
||||
use mcp_types::ReadResourceResult;
|
||||
use mcp_types::Resource;
|
||||
use mcp_types::ResourceTemplate;
|
||||
use mcp_types::Tool;
|
||||
|
||||
use serde_json::json;
|
||||
@@ -102,36 +111,51 @@ enum McpClientAdapter {
|
||||
}
|
||||
|
||||
impl McpClientAdapter {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn new_stdio_client(
|
||||
use_rmcp_client: bool,
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
env_vars: Vec<String>,
|
||||
cwd: Option<PathBuf>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
) -> Result<Self> {
|
||||
if use_rmcp_client {
|
||||
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
|
||||
let client =
|
||||
Arc::new(RmcpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
} else {
|
||||
let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?);
|
||||
let client =
|
||||
Arc::new(McpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Legacy(client))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn new_streamable_http_client(
|
||||
server_name: String,
|
||||
url: String,
|
||||
bearer_token: Option<String>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let client = Arc::new(
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
|
||||
.await?,
|
||||
RmcpClient::new_streamable_http_client(
|
||||
&server_name,
|
||||
&url,
|
||||
bearer_token,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
@@ -148,6 +172,47 @@ impl McpClientAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resources(
|
||||
&self,
|
||||
params: Option<mcp_types::ListResourcesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ListResourcesResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(_) => Ok(ListResourcesResult {
|
||||
next_cursor: None,
|
||||
resources: Vec::new(),
|
||||
}),
|
||||
McpClientAdapter::Rmcp(client) => client.list_resources(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_resource(
|
||||
&self,
|
||||
params: mcp_types::ReadResourceRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ReadResourceResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(_) => Err(anyhow!(
|
||||
"resources/read is not supported by legacy MCP clients"
|
||||
)),
|
||||
McpClientAdapter::Rmcp(client) => client.read_resource(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
params: Option<mcp_types::ListResourceTemplatesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<mcp_types::ListResourceTemplatesResult> {
|
||||
match self {
|
||||
McpClientAdapter::Legacy(_) => Ok(ListResourceTemplatesResult {
|
||||
next_cursor: None,
|
||||
resource_templates: Vec::new(),
|
||||
}),
|
||||
McpClientAdapter::Rmcp(client) => client.list_resource_templates(params, timeout).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
name: String,
|
||||
@@ -246,7 +311,13 @@ impl McpConnectionManager {
|
||||
};
|
||||
|
||||
let client = match transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args,
|
||||
env,
|
                env_vars,
                cwd,
            } => {
                let command_os: OsString = command.into();
                let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
                McpClientAdapter::new_stdio_client(
@@ -254,16 +325,25 @@ impl McpConnectionManager {
                    command_os,
                    args_os,
                    env,
                    env_vars,
                    cwd,
                    params,
                    startup_timeout,
                )
                .await
            }
            McpServerTransportConfig::StreamableHttp { url, .. } => {
            McpServerTransportConfig::StreamableHttp {
                url,
                http_headers,
                env_http_headers,
                ..
            } => {
                McpClientAdapter::new_streamable_http_client(
                    server_name.clone(),
                    url,
                    resolved_bearer_token.unwrap_or_default(),
                    http_headers,
                    env_http_headers,
                    params,
                    startup_timeout,
                    store_mode,
@@ -318,7 +398,7 @@ impl McpConnectionManager {
        Ok((Self { clients, tools }, errors))
    }

    /// Returns a single map that contains **all** tools. Each key is the
    /// Returns a single map that contains all tools. Each key is the
    /// fully-qualified name for the tool.
    pub fn list_all_tools(&self) -> HashMap<String, Tool> {
        self.tools
@@ -327,6 +407,133 @@ impl McpConnectionManager {
            .collect()
    }

    /// Returns a single map that contains all resources. Each key is the
    /// server name and the value is a vector of resources.
    pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
        let mut join_set = JoinSet::new();

        for (server_name, managed_client) in &self.clients {
            let server_name_cloned = server_name.clone();
            let client_clone = managed_client.client.clone();
            let timeout = managed_client.tool_timeout;

            join_set.spawn(async move {
                let mut collected: Vec<Resource> = Vec::new();
                let mut cursor: Option<String> = None;

                loop {
                    let params = cursor.as_ref().map(|next| ListResourcesRequestParams {
                        cursor: Some(next.clone()),
                    });
                    let response = match client_clone.list_resources(params, timeout).await {
                        Ok(result) => result,
                        Err(err) => return (server_name_cloned, Err(err)),
                    };

                    collected.extend(response.resources);

                    match response.next_cursor {
                        Some(next) => {
                            if cursor.as_ref() == Some(&next) {
                                return (
                                    server_name_cloned,
                                    Err(anyhow!("resources/list returned duplicate cursor")),
                                );
                            }
                            cursor = Some(next);
                        }
                        None => return (server_name_cloned, Ok(collected)),
                    }
                }
            });
        }

        let mut aggregated: HashMap<String, Vec<Resource>> = HashMap::new();

        while let Some(join_res) = join_set.join_next().await {
            match join_res {
                Ok((server_name, Ok(resources))) => {
                    aggregated.insert(server_name, resources);
                }
                Ok((server_name, Err(err))) => {
                    warn!("Failed to list resources for MCP server '{server_name}': {err:#}");
                }
                Err(err) => {
                    warn!("Task panic when listing resources for MCP server: {err:#}");
                }
            }
        }

        aggregated
    }

    /// Returns a single map that contains all resource templates. Each key is the
    /// server name and the value is a vector of resource templates.
    pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
        let mut join_set = JoinSet::new();

        for (server_name, managed_client) in &self.clients {
            let server_name_cloned = server_name.clone();
            let client_clone = managed_client.client.clone();
            let timeout = managed_client.tool_timeout;

            join_set.spawn(async move {
                let mut collected: Vec<ResourceTemplate> = Vec::new();
                let mut cursor: Option<String> = None;

                loop {
                    let params = cursor
                        .as_ref()
                        .map(|next| ListResourceTemplatesRequestParams {
                            cursor: Some(next.clone()),
                        });
                    let response = match client_clone.list_resource_templates(params, timeout).await
                    {
                        Ok(result) => result,
                        Err(err) => return (server_name_cloned, Err(err)),
                    };

                    collected.extend(response.resource_templates);

                    match response.next_cursor {
                        Some(next) => {
                            if cursor.as_ref() == Some(&next) {
                                return (
                                    server_name_cloned,
                                    Err(anyhow!(
                                        "resources/templates/list returned duplicate cursor"
                                    )),
                                );
                            }
                            cursor = Some(next);
                        }
                        None => return (server_name_cloned, Ok(collected)),
                    }
                }
            });
        }

        let mut aggregated: HashMap<String, Vec<ResourceTemplate>> = HashMap::new();

        while let Some(join_res) = join_set.join_next().await {
            match join_res {
                Ok((server_name, Ok(templates))) => {
                    aggregated.insert(server_name, templates);
                }
                Ok((server_name, Err(err))) => {
                    warn!(
                        "Failed to list resource templates for MCP server '{server_name}': {err:#}"
                    );
                }
                Err(err) => {
                    warn!("Task panic when listing resource templates for MCP server: {err:#}");
                }
            }
        }

        aggregated
    }
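
Aside: the duplicate-cursor checks in the two pagination loops above guard against a misbehaving server that echoes back the cursor it was just given; without the check, `cursor` would never advance and the loop would spin forever. A minimal sketch of the condition (illustrative values):

// Illustrative only: the guard trips when the server returns the cursor we just sent.
let cursor: Option<String> = Some("page-1".to_string());
let next = "page-1".to_string();
assert_eq!(cursor.as_ref(), Some(&next)); // same cursor twice -> duplicate-cursor error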

    /// Invoke the tool indicated by the (server, tool) pair.
    pub async fn call_tool(
        &self,
@@ -338,7 +545,7 @@ impl McpConnectionManager {
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let client = &managed.client;
        let timeout = managed.tool_timeout;

        client
@@ -347,6 +554,64 @@ impl McpConnectionManager {
            .with_context(|| format!("tool call failed for `{server}/{tool}`"))
    }

    /// List resources from the specified server.
    pub async fn list_resources(
        &self,
        server: &str,
        params: Option<ListResourcesRequestParams>,
    ) -> Result<ListResourcesResult> {
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let timeout = managed.tool_timeout;

        client
            .list_resources(params, timeout)
            .await
            .with_context(|| format!("resources/list failed for `{server}`"))
    }

    /// List resource templates from the specified server.
    pub async fn list_resource_templates(
        &self,
        server: &str,
        params: Option<ListResourceTemplatesRequestParams>,
    ) -> Result<ListResourceTemplatesResult> {
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let timeout = managed.tool_timeout;

        client
            .list_resource_templates(params, timeout)
            .await
            .with_context(|| format!("resources/templates/list failed for `{server}`"))
    }

    /// Read a resource from the specified server.
    pub async fn read_resource(
        &self,
        server: &str,
        params: ReadResourceRequestParams,
    ) -> Result<ReadResourceResult> {
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let timeout = managed.tool_timeout;
        let uri = params.uri.clone();

        client
            .read_resource(params, timeout)
            .await
            .with_context(|| format!("resources/read failed for `{server}` ({uri})"))
    }

    pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
        self.tools
            .get(tool_name)
@@ -382,7 +647,7 @@ fn resolve_bearer_token(
}

/// Query every server for its available tools and return a single map that
/// contains **all** tools. Each key is the fully-qualified name for the tool.
/// contains all tools. Each key is the fully-qualified name for the tool.
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
    let mut join_set = JoinSet::new();

@@ -68,7 +68,11 @@ pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
            auto_compact_token_limit: Some(350_000),
        }),

        _ if slug.starts_with("gpt-5") => Some(ModelInfo::new(272_000, 128_000)),
        _ if slug.starts_with("gpt-5") => Some(ModelInfo {
            context_window: 272_000,
            max_output_tokens: 128_000,
            auto_compact_token_limit: Some(250_000),
        }),

        _ if slug.starts_with("codex-") => Some(ModelInfo::new(272_000, 128_000)),
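
For context, a sketch of what the replaced `ModelInfo::new(272_000, 128_000)` calls presumably expand to; the assumption (not shown in this diff) is that `new` leaves `auto_compact_token_limit` unset, which is the default the hunk above overrides with `Some(250_000)` for gpt-5 models:

// Assumed shape of ModelInfo::new; the constructor body and field types are assumptions.
impl ModelInfo {
    pub fn new(context_window: u64, max_output_tokens: u64) -> Self {
        Self {
            context_window,
            max_output_tokens,
            auto_compact_token_limit: None, // replaced by Some(250_000) in the hunk above
        }
    }
}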

@@ -1,43 +1,9 @@
use crate::bash::try_parse_bash;
use crate::bash::try_parse_word_only_commands_sequence;
use serde::Deserialize;
use serde::Serialize;
use codex_protocol::parse_command::ParsedCommand;
use shlex::split as shlex_split;
use shlex::try_join as shlex_try_join;

#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub enum ParsedCommand {
    Read {
        cmd: String,
        name: String,
    },
    ListFiles {
        cmd: String,
        path: Option<String>,
    },
    Search {
        cmd: String,
        query: Option<String>,
        path: Option<String>,
    },
    Unknown {
        cmd: String,
    },
}

// Convert core's parsed command enum into the protocol's simplified type so
// events can carry the canonical representation across process boundaries.
impl From<ParsedCommand> for codex_protocol::parse_command::ParsedCommand {
    fn from(v: ParsedCommand) -> Self {
        use codex_protocol::parse_command::ParsedCommand as P;
        match v {
            ParsedCommand::Read { cmd, name } => P::Read { cmd, name },
            ParsedCommand::ListFiles { cmd, path } => P::ListFiles { cmd, path },
            ParsedCommand::Search { cmd, query, path } => P::Search { cmd, query, path },
            ParsedCommand::Unknown { cmd } => P::Unknown { cmd },
        }
    }
}
use std::path::PathBuf;

fn shlex_join(tokens: &[String]) -> String {
    shlex_try_join(tokens.iter().map(String::as_str))
@@ -72,6 +38,7 @@ pub fn parse_command(command: &[String]) -> Vec<ParsedCommand> {
/// Tests are at the top to encourage using TDD + Codex to fix the implementation.
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::string::ToString;

    fn shlex_split_safe(s: &str) -> Vec<String> {
@@ -221,6 +188,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "README.md".to_string(),
                path: PathBuf::from("webview/README.md"),
            }],
        );
    }
@@ -232,6 +200,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "cat foo.txt".to_string(),
                name: "foo.txt".to_string(),
                path: PathBuf::from("foo/foo.txt"),
            }],
        );
    }
@@ -254,6 +223,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "cat foo.txt".to_string(),
                name: "foo.txt".to_string(),
                path: PathBuf::from("foo/foo.txt"),
            }],
        );
    }
@@ -278,6 +248,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "Cargo.toml".to_string(),
                path: PathBuf::from("Cargo.toml"),
            }],
        );
    }
@@ -290,6 +261,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "Cargo.toml".to_string(),
                path: PathBuf::from("tui/Cargo.toml"),
            }],
        );
    }
@@ -302,6 +274,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "README.md".to_string(),
                path: PathBuf::from("README.md"),
            }],
        );
    }
@@ -315,6 +288,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "README.md".to_string(),
                path: PathBuf::from("README.md"),
            }]
        );
    }
@@ -484,6 +458,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "parse_command.rs".to_string(),
                path: PathBuf::from("core/src/parse_command.rs"),
            }],
        );
    }
@@ -496,6 +471,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: inner.to_string(),
                name: "history_cell.rs".to_string(),
                path: PathBuf::from("tui/src/history_cell.rs"),
            }],
        );
    }
@@ -509,6 +485,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "cat -- ansi-escape/Cargo.toml".to_string(),
                name: "Cargo.toml".to_string(),
                path: PathBuf::from("ansi-escape/Cargo.toml"),
            }],
        );
    }
@@ -538,6 +515,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "sed -n '260,640p' exec/src/event_processor_with_human_output.rs".to_string(),
                name: "event_processor_with_human_output.rs".to_string(),
                path: PathBuf::from("exec/src/event_processor_with_human_output.rs"),
            }],
        );
    }
@@ -697,6 +675,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: r#"cat "pkg\\src\\main.rs""#.to_string(),
                name: "main.rs".to_string(),
                path: PathBuf::from(r#"pkg\src\main.rs"#),
            }],
        );
    }
@@ -708,6 +687,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "head -n50 Cargo.toml".to_string(),
                name: "Cargo.toml".to_string(),
                path: PathBuf::from("Cargo.toml"),
            }],
        );
    }
@@ -738,6 +718,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "tail -n+10 README.md".to_string(),
                name: "README.md".to_string(),
                path: PathBuf::from("README.md"),
            }],
        );
    }
@@ -774,6 +755,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "cat -- ./-strange-file-name".to_string(),
                name: "-strange-file-name".to_string(),
                path: PathBuf::from("./-strange-file-name"),
            }],
        );

@@ -783,6 +765,7 @@ mod tests {
            vec![ParsedCommand::Read {
                cmd: "sed -n '12,20p' Cargo.toml".to_string(),
                name: "Cargo.toml".to_string(),
                path: PathBuf::from("Cargo.toml"),
            }],
        );
    }
@@ -875,11 +858,39 @@ pub fn parse_command_impl(command: &[String]) -> Vec<ParsedCommand> {
    // Preserve left-to-right execution order for all commands, including bash -c/-lc
    // so summaries reflect the order they will run.

    // Map each pipeline segment to its parsed summary.
    let mut commands: Vec<ParsedCommand> = parts
        .iter()
        .map(|tokens| summarize_main_tokens(tokens))
        .collect();
    // Map each pipeline segment to its parsed summary, tracking `cd` to compute paths.
    let mut commands: Vec<ParsedCommand> = Vec::new();
    let mut cwd: Option<String> = None;
    for tokens in &parts {
        if let Some((head, tail)) = tokens.split_first()
            && head == "cd"
        {
            if let Some(dir) = tail.first() {
                cwd = Some(match &cwd {
                    Some(base) => join_paths(base, dir),
                    None => dir.clone(),
                });
            }
            continue;
        }
        let parsed = summarize_main_tokens(tokens);
        let parsed = match parsed {
            ParsedCommand::Read { cmd, name, path } => {
                if let Some(base) = &cwd {
                    let full = join_paths(base, &path.to_string_lossy());
                    ParsedCommand::Read {
                        cmd,
                        name,
                        path: PathBuf::from(full),
                    }
                } else {
                    ParsedCommand::Read { cmd, name, path }
                }
            }
            other => other,
        };
        commands.push(parsed);
    }

    while let Some(next) = simplify_once(&commands) {
        commands = next;
@@ -1164,10 +1175,39 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
            cmd: script.clone(),
        }]);
    }
    let mut commands: Vec<ParsedCommand> = filtered_commands
        .into_iter()
        .map(|tokens| summarize_main_tokens(&tokens))
        .collect();
    // Build parsed commands, tracking `cd` segments to compute effective file paths.
    let mut commands: Vec<ParsedCommand> = Vec::new();
    let mut cwd: Option<String> = None;
    for tokens in filtered_commands.into_iter() {
        if let Some((head, tail)) = tokens.split_first()
            && head == "cd"
        {
            if let Some(dir) = tail.first() {
                cwd = Some(match &cwd {
                    Some(base) => join_paths(base, dir),
                    None => dir.clone(),
                });
            }
            continue;
        }
        let parsed = summarize_main_tokens(&tokens);
        let parsed = match parsed {
            ParsedCommand::Read { cmd, name, path } => {
                if let Some(base) = &cwd {
                    let full = join_paths(base, &path.to_string_lossy());
                    ParsedCommand::Read {
                        cmd,
                        name,
                        path: PathBuf::from(full),
                    }
                } else {
                    ParsedCommand::Read { cmd, name, path }
                }
            }
            other => other,
        };
        commands.push(parsed);
    }
    if commands.len() > 1 {
        commands.retain(|pc| !matches!(pc, ParsedCommand::Unknown { cmd } if cmd == "true"));
        // Apply the same simplifications used for non-bash parsing, e.g., drop leading `cd`.
@@ -1187,7 +1227,7 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
        commands = commands
            .into_iter()
            .map(|pc| match pc {
                ParsedCommand::Read { name, cmd, .. } => {
                ParsedCommand::Read { name, cmd, path } => {
                    if had_connectors {
                        let has_pipe = script_tokens.iter().any(|t| t == "|");
                        let has_sed_n = script_tokens.windows(2).any(|w| {
@@ -1198,14 +1238,16 @@ fn parse_bash_lc_commands(original: &[String]) -> Option<Vec<ParsedCommand>> {
                            ParsedCommand::Read {
                                cmd: script.clone(),
                                name,
                                path,
                            }
                        } else {
                            ParsedCommand::Read { cmd, name }
                            ParsedCommand::Read { cmd, name, path }
                        }
                    } else {
                        ParsedCommand::Read {
                            cmd: shlex_join(&script_tokens),
                            name,
                            path,
                        }
                    }
                }
@@ -1370,10 +1412,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
                tail
            };
            if effective_tail.len() == 1 {
                let name = short_display_path(&effective_tail[0]);
                let path = effective_tail[0].clone();
                let name = short_display_path(&path);
                ParsedCommand::Read {
                    cmd: shlex_join(main_cmd),
                    name,
                    path: PathBuf::from(path),
                }
            } else {
                ParsedCommand::Unknown {
@@ -1408,10 +1452,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
                i += 1;
            }
            if let Some(p) = candidates.into_iter().find(|p| !p.starts_with('-')) {
                let name = short_display_path(p);
                let path = p.clone();
                let name = short_display_path(&path);
                return ParsedCommand::Read {
                    cmd: shlex_join(main_cmd),
                    name,
                    path: PathBuf::from(path),
                };
            }
        }
@@ -1450,10 +1496,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
                i += 1;
            }
            if let Some(p) = candidates.into_iter().find(|p| !p.starts_with('-')) {
                let name = short_display_path(p);
                let path = p.clone();
                let name = short_display_path(&path);
                return ParsedCommand::Read {
                    cmd: shlex_join(main_cmd),
                    name,
                    path: PathBuf::from(path),
                };
            }
        }
@@ -1465,10 +1513,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
            // Avoid treating option values as paths (e.g., nl -s " ").
            let candidates = skip_flag_values(tail, &["-s", "-w", "-v", "-i", "-b"]);
            if let Some(p) = candidates.into_iter().find(|p| !p.starts_with('-')) {
                let name = short_display_path(p);
                let path = p.clone();
                let name = short_display_path(&path);
                ParsedCommand::Read {
                    cmd: shlex_join(main_cmd),
                    name,
                    path: PathBuf::from(path),
                }
            } else {
                ParsedCommand::Unknown {
@@ -1483,10 +1533,12 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
            && is_valid_sed_n_arg(tail.get(1).map(String::as_str)) =>
        {
            if let Some(path) = tail.get(2) {
                let name = short_display_path(path);
                let path = path.clone();
                let name = short_display_path(&path);
                ParsedCommand::Read {
                    cmd: shlex_join(main_cmd),
                    name,
                    path: PathBuf::from(path),
                }
            } else {
                ParsedCommand::Unknown {
@@ -1500,3 +1552,30 @@ fn summarize_main_tokens(main_cmd: &[String]) -> ParsedCommand {
        },
    }
}

fn is_abs_like(path: &str) -> bool {
    if std::path::Path::new(path).is_absolute() {
        return true;
    }
    let mut chars = path.chars();
    match (chars.next(), chars.next(), chars.next()) {
        // Windows drive path like C:\
        (Some(d), Some(':'), Some('\\')) if d.is_ascii_alphabetic() => return true,
        // UNC path like \\server\share
        (Some('\\'), Some('\\'), _) => return true,
        _ => {}
    }
    false
}

fn join_paths(base: &str, rel: &str) -> String {
    if is_abs_like(rel) {
        return rel.to_string();
    }
    if base.is_empty() {
        return rel.to_string();
    }
    let mut buf = PathBuf::from(base);
    buf.push(rel);
    buf.to_string_lossy().to_string()
}
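
The behavior of the two helpers above, written as the kind of assertions the test module earlier in this file uses; relative paths are joined onto the tracked `cd` base, while absolute and Windows-style paths pass through unchanged (Unix separators shown):

// Illustrative expectations for join_paths / is_abs_like on a Unix host.
assert_eq!(join_paths("webview", "README.md"), "webview/README.md");
assert_eq!(join_paths("webview", "/tmp/README.md"), "/tmp/README.md");
assert!(is_abs_like(r"C:\repo\Cargo.toml")); // Windows drive path
assert!(is_abs_like(r"\\server\share")); // UNC path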
@@ -21,6 +21,8 @@ use tracing::error;
|
||||
|
||||
/// Default filename scanned for project-level docs.
|
||||
pub const DEFAULT_PROJECT_DOC_FILENAME: &str = "AGENTS.md";
|
||||
/// Preferred local override for project-level docs.
|
||||
pub const LOCAL_PROJECT_DOC_FILENAME: &str = "AGENTS.override.md";
|
||||
|
||||
/// When both `Config::instructions` and the project doc are present, they will
|
||||
/// be concatenated with the following separator.
|
||||
@@ -178,7 +180,8 @@ pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBu
|
||||
|
||||
fn candidate_filenames<'a>(config: &'a Config) -> Vec<&'a str> {
|
||||
let mut names: Vec<&'a str> =
|
||||
Vec::with_capacity(1 + config.project_doc_fallback_filenames.len());
|
||||
Vec::with_capacity(2 + config.project_doc_fallback_filenames.len());
|
||||
names.push(LOCAL_PROJECT_DOC_FILENAME);
|
||||
names.push(DEFAULT_PROJECT_DOC_FILENAME);
|
||||
for candidate in &config.project_doc_fallback_filenames {
|
||||
let candidate = candidate.as_str();
|
||||
@@ -381,6 +384,29 @@ mod tests {
|
||||
assert_eq!(res, "root doc\n\ncrate doc");
|
||||
}
|
||||
|
||||
/// AGENTS.override.md is preferred over AGENTS.md when both are present.
|
||||
#[tokio::test]
|
||||
async fn agents_local_md_preferred() {
|
||||
let tmp = tempfile::tempdir().expect("tempdir");
|
||||
fs::write(tmp.path().join(DEFAULT_PROJECT_DOC_FILENAME), "versioned").unwrap();
|
||||
fs::write(tmp.path().join(LOCAL_PROJECT_DOC_FILENAME), "local").unwrap();
|
||||
|
||||
let cfg = make_config(&tmp, 4096, None);
|
||||
|
||||
let res = get_user_instructions(&cfg)
|
||||
.await
|
||||
.expect("local doc expected");
|
||||
|
||||
assert_eq!(res, "local");
|
||||
|
||||
let discovery = discover_project_doc_paths(&cfg).expect("discover paths");
|
||||
assert_eq!(discovery.len(), 1);
|
||||
assert_eq!(
|
||||
discovery[0].file_name().unwrap().to_string_lossy(),
|
||||
LOCAL_PROJECT_DOC_FILENAME
|
||||
);
|
||||
}
|
||||
|
||||
/// When AGENTS.md is absent but a configured fallback exists, the fallback is used.
|
||||
#[tokio::test]
|
||||
async fn uses_configured_fallback_when_agents_missing() {
|
||||
|
||||
773 codex-rs/core/src/tools/handlers/mcp_resource.rs Normal file
@@ -0,0 +1,773 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use async_trait::async_trait;
use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::ListResourceTemplatesRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ListResourcesRequestParams;
use mcp_types::ListResourcesResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResult;
use mcp_types::Resource;
use mcp_types::ResourceTemplate;
use mcp_types::TextContent;
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;

use crate::codex::Session;
use crate::function_tool::FunctionCallError;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::McpInvocation;
use crate::protocol::McpToolCallBeginEvent;
use crate::protocol::McpToolCallEndEvent;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct McpResourceHandler;

#[derive(Debug, Deserialize, Default)]
struct ListResourcesArgs {
    /// Lists all resources from all servers if not specified.
    #[serde(default)]
    server: Option<String>,
    #[serde(default)]
    cursor: Option<String>,
}

#[derive(Debug, Deserialize, Default)]
struct ListResourceTemplatesArgs {
    /// Lists all resource templates from all servers if not specified.
    #[serde(default)]
    server: Option<String>,
    #[serde(default)]
    cursor: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ReadResourceArgs {
    server: String,
    uri: String,
}

#[derive(Debug, Serialize)]
struct ResourceWithServer {
    server: String,
    #[serde(flatten)]
    resource: Resource,
}

impl ResourceWithServer {
    fn new(server: String, resource: Resource) -> Self {
        Self { server, resource }
    }
}

#[derive(Debug, Serialize)]
struct ResourceTemplateWithServer {
    server: String,
    #[serde(flatten)]
    template: ResourceTemplate,
}

impl ResourceTemplateWithServer {
    fn new(server: String, template: ResourceTemplate) -> Self {
        Self { server, template }
    }
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListResourcesPayload {
    #[serde(skip_serializing_if = "Option::is_none")]
    server: Option<String>,
    resources: Vec<ResourceWithServer>,
    #[serde(skip_serializing_if = "Option::is_none")]
    next_cursor: Option<String>,
}

impl ListResourcesPayload {
    fn from_single_server(server: String, result: ListResourcesResult) -> Self {
        let resources = result
            .resources
            .into_iter()
            .map(|resource| ResourceWithServer::new(server.clone(), resource))
            .collect();
        Self {
            server: Some(server),
            resources,
            next_cursor: result.next_cursor,
        }
    }

    fn from_all_servers(resources_by_server: HashMap<String, Vec<Resource>>) -> Self {
        let mut entries: Vec<(String, Vec<Resource>)> = resources_by_server.into_iter().collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        let mut resources = Vec::new();
        for (server, server_resources) in entries {
            for resource in server_resources {
                resources.push(ResourceWithServer::new(server.clone(), resource));
            }
        }

        Self {
            server: None,
            resources,
            next_cursor: None,
        }
    }
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListResourceTemplatesPayload {
    #[serde(skip_serializing_if = "Option::is_none")]
    server: Option<String>,
    resource_templates: Vec<ResourceTemplateWithServer>,
    #[serde(skip_serializing_if = "Option::is_none")]
    next_cursor: Option<String>,
}

impl ListResourceTemplatesPayload {
    fn from_single_server(server: String, result: ListResourceTemplatesResult) -> Self {
        let resource_templates = result
            .resource_templates
            .into_iter()
            .map(|template| ResourceTemplateWithServer::new(server.clone(), template))
            .collect();
        Self {
            server: Some(server),
            resource_templates,
            next_cursor: result.next_cursor,
        }
    }

    fn from_all_servers(templates_by_server: HashMap<String, Vec<ResourceTemplate>>) -> Self {
        let mut entries: Vec<(String, Vec<ResourceTemplate>)> =
            templates_by_server.into_iter().collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        let mut resource_templates = Vec::new();
        for (server, server_templates) in entries {
            for template in server_templates {
                resource_templates.push(ResourceTemplateWithServer::new(server.clone(), template));
            }
        }

        Self {
            server: None,
            resource_templates,
            next_cursor: None,
        }
    }
}

#[derive(Debug, Serialize)]
struct ReadResourcePayload {
    server: String,
    uri: String,
    #[serde(flatten)]
    result: ReadResourceResult,
}

#[async_trait]
impl ToolHandler for McpResourceHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            sub_id,
            call_id,
            tool_name,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "mcp_resource handler received unsupported payload".to_string(),
                ));
            }
        };

        let arguments_value = parse_arguments(arguments.as_str())?;

        match tool_name.as_str() {
            "list_mcp_resources" => {
                handle_list_resources(
                    Arc::clone(&session),
                    sub_id.clone(),
                    call_id.clone(),
                    arguments_value.clone(),
                )
                .await
            }
            "list_mcp_resource_templates" => {
                handle_list_resource_templates(
                    Arc::clone(&session),
                    sub_id.clone(),
                    call_id.clone(),
                    arguments_value.clone(),
                )
                .await
            }
            "read_mcp_resource" => {
                handle_read_resource(Arc::clone(&session), sub_id, call_id, arguments_value).await
            }
            other => Err(FunctionCallError::RespondToModel(format!(
                "unsupported MCP resource tool: {other}"
            ))),
        }
    }
}

async fn handle_list_resources(
    session: Arc<Session>,
    sub_id: String,
    call_id: String,
    arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
    let args: ListResourcesArgs = parse_args_with_default(arguments.clone())?;
    let ListResourcesArgs { server, cursor } = args;
    let server = normalize_optional_string(server);
    let cursor = normalize_optional_string(cursor);

    let invocation = McpInvocation {
        server: server.clone().unwrap_or_else(|| "codex".to_string()),
        tool: "list_mcp_resources".to_string(),
        arguments: arguments.clone(),
    };

    emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
    let start = Instant::now();

    let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
        if let Some(server_name) = server.clone() {
            let params = cursor.clone().map(|value| ListResourcesRequestParams {
                cursor: Some(value),
            });
            let result = session
                .list_resources(&server_name, params)
                .await
                .map_err(|err| {
                    FunctionCallError::RespondToModel(format!("resources/list failed: {err:#}"))
                })?;
            Ok(ListResourcesPayload::from_single_server(
                server_name,
                result,
            ))
        } else {
            if cursor.is_some() {
                return Err(FunctionCallError::RespondToModel(
                    "cursor can only be used when a server is specified".to_string(),
                ));
            }

            let resources = session
                .services
                .mcp_connection_manager
                .list_all_resources()
                .await;
            Ok(ListResourcesPayload::from_all_servers(resources))
        }
    }
    .await;

    match payload_result {
        Ok(payload) => match serialize_function_output(payload) {
            Ok(output) => {
                let ToolOutput::Function { content, success } = &output else {
                    unreachable!("MCP resource handler should return function output");
                };
                let duration = start.elapsed();
                emit_tool_call_end(
                    &session,
                    &sub_id,
                    &call_id,
                    invocation,
                    duration,
                    Ok(call_tool_result_from_content(content, *success)),
                )
                .await;
                Ok(output)
            }
            Err(err) => {
                let duration = start.elapsed();
                let message = err.to_string();
                emit_tool_call_end(
                    &session,
                    &sub_id,
                    &call_id,
                    invocation,
                    duration,
                    Err(message.clone()),
                )
                .await;
                Err(err)
            }
        },
        Err(err) => {
            let duration = start.elapsed();
            let message = err.to_string();
            emit_tool_call_end(
                &session,
                &sub_id,
                &call_id,
                invocation,
                duration,
                Err(message.clone()),
            )
            .await;
            Err(err)
        }
    }
}

async fn handle_list_resource_templates(
    session: Arc<Session>,
    sub_id: String,
    call_id: String,
    arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
    let args: ListResourceTemplatesArgs = parse_args_with_default(arguments.clone())?;
    let ListResourceTemplatesArgs { server, cursor } = args;
    let server = normalize_optional_string(server);
    let cursor = normalize_optional_string(cursor);

    let invocation = McpInvocation {
        server: server.clone().unwrap_or_else(|| "codex".to_string()),
        tool: "list_mcp_resource_templates".to_string(),
        arguments: arguments.clone(),
    };

    emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
    let start = Instant::now();

    let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
        if let Some(server_name) = server.clone() {
            let params = cursor
                .clone()
                .map(|value| ListResourceTemplatesRequestParams {
                    cursor: Some(value),
                });
            let result = session
                .list_resource_templates(&server_name, params)
                .await
                .map_err(|err| {
                    FunctionCallError::RespondToModel(format!(
                        "resources/templates/list failed: {err:#}"
                    ))
                })?;
            Ok(ListResourceTemplatesPayload::from_single_server(
                server_name,
                result,
            ))
        } else {
            if cursor.is_some() {
                return Err(FunctionCallError::RespondToModel(
                    "cursor can only be used when a server is specified".to_string(),
                ));
            }

            let templates = session
                .services
                .mcp_connection_manager
                .list_all_resource_templates()
                .await;
            Ok(ListResourceTemplatesPayload::from_all_servers(templates))
        }
    }
    .await;

    match payload_result {
        Ok(payload) => match serialize_function_output(payload) {
            Ok(output) => {
                let ToolOutput::Function { content, success } = &output else {
                    unreachable!("MCP resource handler should return function output");
                };
                let duration = start.elapsed();
                emit_tool_call_end(
                    &session,
                    &sub_id,
                    &call_id,
                    invocation,
                    duration,
                    Ok(call_tool_result_from_content(content, *success)),
                )
                .await;
                Ok(output)
            }
            Err(err) => {
                let duration = start.elapsed();
                let message = err.to_string();
                emit_tool_call_end(
                    &session,
                    &sub_id,
                    &call_id,
                    invocation,
                    duration,
                    Err(message.clone()),
                )
                .await;
                Err(err)
            }
        },
        Err(err) => {
            let duration = start.elapsed();
            let message = err.to_string();
            emit_tool_call_end(
                &session,
                &sub_id,
                &call_id,
                invocation,
                duration,
                Err(message.clone()),
            )
            .await;
            Err(err)
        }
    }
}

async fn handle_read_resource(
    session: Arc<Session>,
    sub_id: String,
    call_id: String,
    arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
    let args: ReadResourceArgs = parse_args(arguments.clone())?;
    let ReadResourceArgs { server, uri } = args;
    let server = normalize_required_string("server", server)?;
    let uri = normalize_required_string("uri", uri)?;

    let invocation = McpInvocation {
        server: server.clone(),
        tool: "read_mcp_resource".to_string(),
        arguments: arguments.clone(),
    };

    emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
    let start = Instant::now();

    let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
        let result = session
            .read_resource(&server, ReadResourceRequestParams { uri: uri.clone() })
            .await
            .map_err(|err| {
                FunctionCallError::RespondToModel(format!("resources/read failed: {err:#}"))
            })?;

        Ok(ReadResourcePayload {
            server,
            uri,
            result,
        })
    }
    .await;

    match payload_result {
        Ok(payload) => match serialize_function_output(payload) {
            Ok(output) => {
                let ToolOutput::Function { content, success } = &output else {
                    unreachable!("MCP resource handler should return function output");
                };
                let duration = start.elapsed();
                emit_tool_call_end(
                    &session,
                    &sub_id,
                    &call_id,
                    invocation,
                    duration,
                    Ok(call_tool_result_from_content(content, *success)),
                )
                .await;
                Ok(output)
            }
            Err(err) => {
                let duration = start.elapsed();
                let message = err.to_string();
                emit_tool_call_end(
                    &session,
                    &sub_id,
                    &call_id,
                    invocation,
                    duration,
                    Err(message.clone()),
                )
                .await;
                Err(err)
            }
        },
        Err(err) => {
            let duration = start.elapsed();
            let message = err.to_string();
            emit_tool_call_end(
                &session,
                &sub_id,
                &call_id,
                invocation,
                duration,
                Err(message.clone()),
            )
            .await;
            Err(err)
        }
    }
}

fn call_tool_result_from_content(content: &str, success: Option<bool>) -> CallToolResult {
    CallToolResult {
        content: vec![ContentBlock::TextContent(TextContent {
            annotations: None,
            text: content.to_string(),
            r#type: "text".to_string(),
        })],
        is_error: success.map(|value| !value),
        structured_content: None,
    }
}

async fn emit_tool_call_begin(
    session: &Arc<Session>,
    sub_id: &str,
    call_id: &str,
    invocation: McpInvocation,
) {
    session
        .send_event(Event {
            id: sub_id.to_string(),
            msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
                call_id: call_id.to_string(),
                invocation,
            }),
        })
        .await;
}

async fn emit_tool_call_end(
    session: &Arc<Session>,
    sub_id: &str,
    call_id: &str,
    invocation: McpInvocation,
    duration: Duration,
    result: Result<CallToolResult, String>,
) {
    session
        .send_event(Event {
            id: sub_id.to_string(),
            msg: EventMsg::McpToolCallEnd(McpToolCallEndEvent {
                call_id: call_id.to_string(),
                invocation,
                duration,
                result,
            }),
        })
        .await;
}

fn normalize_optional_string(input: Option<String>) -> Option<String> {
    input.and_then(|value| {
        let trimmed = value.trim().to_string();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed)
        }
    })
}

fn normalize_required_string(field: &str, value: String) -> Result<String, FunctionCallError> {
    match normalize_optional_string(Some(value)) {
        Some(normalized) => Ok(normalized),
        None => Err(FunctionCallError::RespondToModel(format!(
            "{field} must be provided"
        ))),
    }
}

fn serialize_function_output<T>(payload: T) -> Result<ToolOutput, FunctionCallError>
where
    T: Serialize,
{
    let content = serde_json::to_string(&payload).map_err(|err| {
        FunctionCallError::RespondToModel(format!(
            "failed to serialize MCP resource response: {err}"
        ))
    })?;

    Ok(ToolOutput::Function {
        content,
        success: Some(true),
    })
}

fn parse_arguments(raw_args: &str) -> Result<Option<Value>, FunctionCallError> {
    if raw_args.trim().is_empty() {
        Ok(None)
    } else {
        serde_json::from_str(raw_args).map(Some).map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
        })
    }
}

fn parse_args<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
where
    T: DeserializeOwned,
{
    match arguments {
        Some(value) => serde_json::from_value(value).map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
        }),
        None => Err(FunctionCallError::RespondToModel(
            "failed to parse function arguments: expected value".to_string(),
        )),
    }
}

fn parse_args_with_default<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
where
    T: DeserializeOwned + Default,
{
    match arguments {
        Some(value) => parse_args(Some(value)),
        None => Ok(T::default()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mcp_types::ListResourcesResult;
    use mcp_types::ResourceTemplate;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    fn resource(uri: &str, name: &str) -> Resource {
        Resource {
            annotations: None,
            description: None,
            mime_type: None,
            name: name.to_string(),
            size: None,
            title: None,
            uri: uri.to_string(),
        }
    }

    fn template(uri_template: &str, name: &str) -> ResourceTemplate {
        ResourceTemplate {
            annotations: None,
            description: None,
            mime_type: None,
            name: name.to_string(),
            title: None,
            uri_template: uri_template.to_string(),
        }
    }

    #[test]
    fn resource_with_server_serializes_server_field() {
        let entry = ResourceWithServer::new("test".to_string(), resource("memo://id", "memo"));
        let value = serde_json::to_value(&entry).expect("serialize resource");

        assert_eq!(value["server"], json!("test"));
        assert_eq!(value["uri"], json!("memo://id"));
        assert_eq!(value["name"], json!("memo"));
    }

    #[test]
    fn list_resources_payload_from_single_server_copies_next_cursor() {
        let result = ListResourcesResult {
            next_cursor: Some("cursor-1".to_string()),
            resources: vec![resource("memo://id", "memo")],
        };
        let payload = ListResourcesPayload::from_single_server("srv".to_string(), result);
        let value = serde_json::to_value(&payload).expect("serialize payload");

        assert_eq!(value["server"], json!("srv"));
        assert_eq!(value["nextCursor"], json!("cursor-1"));
        let resources = value["resources"].as_array().expect("resources array");
        assert_eq!(resources.len(), 1);
        assert_eq!(resources[0]["server"], json!("srv"));
    }

    #[test]
    fn list_resources_payload_from_all_servers_is_sorted() {
        let mut map = HashMap::new();
        map.insert("beta".to_string(), vec![resource("memo://b-1", "b-1")]);
        map.insert(
            "alpha".to_string(),
            vec![resource("memo://a-1", "a-1"), resource("memo://a-2", "a-2")],
        );

        let payload = ListResourcesPayload::from_all_servers(map);
        let value = serde_json::to_value(&payload).expect("serialize payload");
        let uris: Vec<String> = value["resources"]
            .as_array()
            .expect("resources array")
            .iter()
            .map(|entry| entry["uri"].as_str().unwrap().to_string())
            .collect();

        assert_eq!(
            uris,
            vec![
                "memo://a-1".to_string(),
                "memo://a-2".to_string(),
                "memo://b-1".to_string()
            ]
        );
    }

    #[test]
    fn call_tool_result_from_content_marks_success() {
        let result = call_tool_result_from_content("{}", Some(true));
        assert_eq!(result.is_error, Some(false));
        assert_eq!(result.content.len(), 1);
    }

    #[test]
    fn parse_arguments_handles_empty_and_json() {
        assert!(
            parse_arguments(" \n\t").unwrap().is_none(),
            "expected None for empty arguments"
        );

        let value = parse_arguments(r#"{"server":"figma"}"#)
            .expect("parse json")
            .expect("value present");
        assert_eq!(value["server"], json!("figma"));
    }

    #[test]
    fn template_with_server_serializes_server_field() {
        let entry =
            ResourceTemplateWithServer::new("srv".to_string(), template("memo://{id}", "memo"));
        let value = serde_json::to_value(&entry).expect("serialize template");

        assert_eq!(
            value,
            json!({
                "server": "srv",
                "uriTemplate": "memo://{id}",
                "name": "memo"
            })
        );
    }
}
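
Given the serde attributes on the payload structs above (camelCase field renaming, the flattened `Resource`, and `nextCursor` omitted when `None`), a single-server `list_mcp_resources` response serializes to JSON of roughly this shape; the server name, URI, and cursor below are invented for the example:

// Sketch of the wire shape (illustrative values, serde_json assumed in scope).
let expected = serde_json::json!({
    "server": "docs",
    "resources": [
        { "server": "docs", "uri": "memo://guide", "name": "guide" }
    ],
    "nextCursor": "cursor-2"
});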

@@ -3,6 +3,7 @@ mod exec_stream;
mod grep_files;
mod list_dir;
mod mcp;
mod mcp_resource;
mod plan;
mod read_file;
mod shell;
@@ -17,6 +18,7 @@ pub use exec_stream::ExecStreamHandler;
pub use grep_files::GrepFilesHandler;
pub use list_dir::ListDirHandler;
pub use mcp::McpHandler;
pub use mcp_resource::McpResourceHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
@@ -238,6 +238,7 @@ fn truncate_function_error(err: FunctionCallError) -> FunctionCallError {
|
||||
FunctionCallError::RespondToModel(msg) => {
|
||||
FunctionCallError::RespondToModel(format_exec_output(&msg))
|
||||
}
|
||||
FunctionCallError::Denied(msg) => FunctionCallError::Denied(format_exec_output(&msg)),
|
||||
FunctionCallError::Fatal(msg) => FunctionCallError::Fatal(format_exec_output(&msg)),
|
||||
other => other,
|
||||
}
|
||||
|

@@ -511,6 +511,107 @@ fn create_list_dir_tool() -> ToolSpec {
        },
    })
}

fn create_list_mcp_resources_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(
        "server".to_string(),
        JsonSchema::String {
            description: Some(
                "Optional MCP server name. When omitted, lists resources from every configured server."
                    .to_string(),
            ),
        },
    );
    properties.insert(
        "cursor".to_string(),
        JsonSchema::String {
            description: Some(
                "Opaque cursor returned by a previous list_mcp_resources call for the same server."
                    .to_string(),
            ),
        },
    );

    ToolSpec::Function(ResponsesApiTool {
        name: "list_mcp_resources".to_string(),
        description: "Lists resources provided by MCP servers. Resources allow servers to share data that provides context to language models, such as files, database schemas, or application-specific information. Prefer resources over web search when possible.".to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: None,
            additional_properties: Some(false.into()),
        },
    })
}

fn create_list_mcp_resource_templates_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(
        "server".to_string(),
        JsonSchema::String {
            description: Some(
                "Optional MCP server name. When omitted, lists resource templates from all configured servers."
                    .to_string(),
            ),
        },
    );
    properties.insert(
        "cursor".to_string(),
        JsonSchema::String {
            description: Some(
                "Opaque cursor returned by a previous list_mcp_resource_templates call for the same server."
                    .to_string(),
            ),
        },
    );

    ToolSpec::Function(ResponsesApiTool {
        name: "list_mcp_resource_templates".to_string(),
        description: "Lists resource templates provided by MCP servers. Parameterized resource templates allow servers to share data that takes parameters and provides context to language models, such as files, database schemas, or application-specific information. Prefer resource templates over web search when possible.".to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: None,
            additional_properties: Some(false.into()),
        },
    })
}

fn create_read_mcp_resource_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(
        "server".to_string(),
        JsonSchema::String {
            description: Some(
                "MCP server name exactly as configured. Must match the 'server' field returned by list_mcp_resources."
                    .to_string(),
            ),
        },
    );
    properties.insert(
        "uri".to_string(),
        JsonSchema::String {
            description: Some(
                "Resource URI to read. Must be one of the URIs returned by list_mcp_resources."
                    .to_string(),
            ),
        },
    );

    ToolSpec::Function(ResponsesApiTool {
        name: "read_mcp_resource".to_string(),
        description:
            "Read a specific resource from an MCP server given the server name and resource URI."
                .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: Some(vec!["server".to_string(), "uri".to_string()]),
            additional_properties: Some(false.into()),
        },
    })
}
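
Per the three schemas above, the arguments the model supplies look like the following; the server name and URI are invented, and recall from the handler that a `cursor` without a `server` is rejected:

// Illustrative argument payloads for the three resource tools (serde_json assumed in scope).
let list_everything = serde_json::json!({}); // list_mcp_resources across all servers
let list_next_page = serde_json::json!({ "server": "docs", "cursor": "c-1" });
let read_one = serde_json::json!({ "server": "docs", "uri": "memo://guide" });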
/// TODO(dylan): deprecate once we get rid of json tool
#[derive(Serialize, Deserialize)]
pub(crate) struct ApplyPatchToolArgs {
@@ -723,6 +824,7 @@ pub(crate) fn build_specs(
    use crate::tools::handlers::GrepFilesHandler;
    use crate::tools::handlers::ListDirHandler;
    use crate::tools::handlers::McpHandler;
    use crate::tools::handlers::McpResourceHandler;
    use crate::tools::handlers::PlanHandler;
    use crate::tools::handlers::ReadFileHandler;
    use crate::tools::handlers::ShellHandler;
@@ -740,6 +842,7 @@ pub(crate) fn build_specs(
    let apply_patch_handler = Arc::new(ApplyPatchHandler);
    let view_image_handler = Arc::new(ViewImageHandler);
    let mcp_handler = Arc::new(McpHandler);
    let mcp_resource_handler = Arc::new(McpResourceHandler);

    if config.experimental_unified_exec_tool {
        builder.push_spec(create_unified_exec_tool());
@@ -770,6 +873,13 @@ pub(crate) fn build_specs(
    builder.register_handler("container.exec", shell_handler.clone());
    builder.register_handler("local_shell", shell_handler);

    builder.push_spec_with_parallel_support(create_list_mcp_resources_tool(), true);
    builder.push_spec_with_parallel_support(create_list_mcp_resource_templates_tool(), true);
    builder.push_spec_with_parallel_support(create_read_mcp_resource_tool(), true);
    builder.register_handler("list_mcp_resources", mcp_resource_handler.clone());
    builder.register_handler("list_mcp_resource_templates", mcp_resource_handler.clone());
    builder.register_handler("read_mcp_resource", mcp_resource_handler);

    if config.plan_tool {
        builder.push_spec(PLAN_TOOL.clone());
        builder.register_handler("update_plan", plan_handler);
@@ -917,7 +1027,15 @@ mod tests {

        assert_eq_tool_names(
            &tools,
            &["unified_exec", "update_plan", "web_search", "view_image"],
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "update_plan",
                "web_search",
                "view_image",
            ],
        );
    }

@@ -936,7 +1054,15 @@ mod tests {

        assert_eq_tool_names(
            &tools,
            &["unified_exec", "update_plan", "web_search", "view_image"],
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "update_plan",
                "web_search",
                "view_image",
            ],
        );
    }

@@ -1043,15 +1169,19 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "web_search",
                "view_image",
                "test_server/do_something_cool",
            ],
        );

        let tool = find_tool(&tools, "test_server/do_something_cool");
        assert_eq!(
            tools[3].spec,
            ToolSpec::Function(ResponsesApiTool {
            &tool.spec,
            &ToolSpec::Function(ResponsesApiTool {
                name: "test_server/do_something_cool".to_string(),
                parameters: JsonSchema::Object {
                    properties: BTreeMap::from([
@@ -1158,6 +1288,9 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "view_image",
                "test_server/cool",
                "test_server/do",
@@ -1206,6 +1339,9 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "apply_patch",
                "web_search",
                "view_image",
@@ -1214,7 +1350,7 @@ mod tests {
        );

        assert_eq!(
            tools[4].spec,
            tools[7].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/search".to_string(),
                parameters: JsonSchema::Object {
@@ -1271,6 +1407,9 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "apply_patch",
                "web_search",
                "view_image",
@@ -1278,7 +1417,7 @@ mod tests {
            ],
        );
        assert_eq!(
            tools[4].spec,
            tools[7].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/paginate".to_string(),
                parameters: JsonSchema::Object {
@@ -1334,6 +1473,9 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "apply_patch",
                "web_search",
                "view_image",
@@ -1341,7 +1483,7 @@ mod tests {
            ],
        );
        assert_eq!(
            tools[4].spec,
            tools[7].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/tags".to_string(),
                parameters: JsonSchema::Object {
@@ -1399,6 +1541,9 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "apply_patch",
                "web_search",
                "view_image",
@@ -1406,7 +1551,7 @@ mod tests {
            ],
        );
        assert_eq!(
            tools[4].spec,
            tools[7].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "dash/value".to_string(),
                parameters: JsonSchema::Object {
@@ -1501,6 +1646,9 @@ mod tests {
            &tools,
            &[
                "unified_exec",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "apply_patch",
                "web_search",
                "view_image",
@@ -1509,7 +1657,7 @@ mod tests {
        );

        assert_eq!(
            tools[4].spec,
            tools[7].spec,
            ToolSpec::Function(ResponsesApiTool {
                name: "test_server/do_something_cool".to_string(),
                parameters: JsonSchema::Object {

@@ -79,6 +79,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
        config.model.as_str(),
        config.model_family.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),

@@ -78,6 +78,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
        config.model.as_str(),
        config.model_family.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),

@@ -63,6 +63,7 @@ async fn responses_stream_includes_task_type_header() {
        config.model.as_str(),
        config.model_family.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),

@@ -657,6 +657,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
        config.model.as_str(),
        config.model_family.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        Some(AuthMode::ChatGPT),
        false,
        "test".to_string(),

171 codex-rs/core/tests/suite/cross_session.rs Normal file
@@ -0,0 +1,171 @@
#![cfg(not(target_os = "windows"))]

use std::sync::Arc;
use std::time::Duration;

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::CrossSessionSpawnParams;
use codex_core::built_in_model_providers;
use codex_core::config::Config;
use codex_core::cross_session::AssistantMessage;
use codex_core::cross_session::CrossSessionHub;
use codex_core::cross_session::PostUserTurnRequest;
use codex_core::cross_session::RoleOrId;
use codex_core::cross_session::SessionEventStream;
use codex_core::protocol::EventMsg;
use core_test_support::load_default_config_for_test;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use tempfile::TempDir;
use tokio_stream::StreamExt;
use wiremock::MockServer;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cross_session_hub_routes_between_roles() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;
    let bodies = vec![
        responses::sse(vec![
            responses::ev_response_created("solver-resp-1"),
            responses::ev_assistant_message("solver-msg-1", "Need direction"),
            responses::ev_completed("solver-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("director-resp-1"),
            responses::ev_assistant_message("director-msg-1", "Proceed iteratively"),
            responses::ev_completed("director-resp-1"),
        ]),
        responses::sse(vec![
            responses::ev_response_created("solver-resp-2"),
            responses::ev_assistant_message("solver-msg-2", "Acknowledged"),
            responses::ev_completed("solver-resp-2"),
        ]),
    ];
    let response_mock = responses::mount_sse_sequence(&server, bodies).await;

    let hub = Arc::new(CrossSessionHub::new());
    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy-key"));
    let run_id = "run-cross-session".to_string();

    let solver_config = build_config(&server).await?;
    let solver = conversation_manager
        .new_conversation_with_cross_session(
            solver_config,
            CrossSessionSpawnParams {
                hub: Arc::clone(&hub),
                run_id: Some(run_id.clone()),
                role: Some("solver".to_string()),
            },
        )
        .await?;

    let director_config = build_config(&server).await?;
    let director = conversation_manager
        .new_conversation_with_cross_session(
            director_config,
            CrossSessionSpawnParams {
                hub: Arc::clone(&hub),
                run_id: Some(run_id.clone()),
                role: Some("director".to_string()),
            },
        )
        .await?;

    let mut solver_events = hub.stream_events(solver.conversation_id)?;
    let mut director_events = hub.stream_events(director.conversation_id)?;

    let solver_handle = hub
        .post_user_turn(PostUserTurnRequest {
            target: RoleOrId::RunRole {
                run_id: run_id.clone(),
                role: "solver".to_string(),
            },
            text: "kick off plan".to_string(),
            final_output_json_schema: None,
        })
        .await?;
    let solver_first = expect_message(&hub, &solver_handle, "Need direction").await?;

    let director_handle = hub
        .post_user_turn(PostUserTurnRequest {
            target: RoleOrId::RunRole {
                run_id: run_id.clone(),
                role: "director".to_string(),
            },
            text: solver_first.message.message.clone(),
            final_output_json_schema: None,
        })
        .await?;
    let director_first = expect_message(&hub, &director_handle, "Proceed iteratively").await?;

    let solver_followup = hub
        .post_user_turn(PostUserTurnRequest {
            target: RoleOrId::Session(solver.conversation_id),
            text: director_first.message.message.clone(),
            final_output_json_schema: None,
        })
        .await?;
    let solver_reply = expect_message(&hub, &solver_followup, "Acknowledged").await?;

    let solver_event = expect_agent_event(&mut solver_events).await;
    match solver_event {
        EventMsg::AgentMessage(msg) => assert_eq!(msg.message, "Need direction"),
        _ => panic!("expected solver agent message"),
    }

    let director_event = expect_agent_event(&mut director_events).await;
    match director_event {
        EventMsg::AgentMessage(msg) => assert_eq!(msg.message, "Proceed iteratively"),
        _ => panic!("expected director agent message"),
    }

    assert_eq!(solver_first.message.message, "Need direction");
    assert_eq!(director_first.message.message, "Proceed iteratively");
|
||||
assert_eq!(solver_reply.message.message, "Acknowledged");
|
||||
assert_eq!(response_mock.requests().len(), 3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn build_config(server: &MockServer) -> anyhow::Result<Config> {
|
||||
let home = TempDir::new()?;
|
||||
let cwd = TempDir::new()?;
|
||||
let mut config = load_default_config_for_test(&home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
let mut provider = built_in_model_providers()["openai"].clone();
|
||||
provider.base_url = Some(format!("{}/v1", server.uri()));
|
||||
config.model_provider = provider;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
async fn expect_message(
|
||||
hub: &CrossSessionHub,
|
||||
handle: &codex_core::cross_session::TurnHandle,
|
||||
expected: &str,
|
||||
) -> anyhow::Result<AssistantMessage> {
|
||||
let message = hub
|
||||
.await_first_assistant(handle, Duration::from_secs(1))
|
||||
.await?;
|
||||
assert_eq!(message.message.message, expected);
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
async fn expect_agent_event(stream: &mut SessionEventStream) -> EventMsg {
|
||||
loop {
|
||||
let maybe_event = match tokio::time::timeout(Duration::from_secs(1), stream.next()).await {
|
||||
Ok(event) => event,
|
||||
Err(_) => panic!("event timeout"),
|
||||
};
|
||||
|
||||
if let Some(event) = maybe_event {
|
||||
let msg = event.event.msg;
|
||||
if matches!(msg, EventMsg::AgentMessage(_)) {
|
||||
return msg;
|
||||
}
|
||||
} else {
|
||||
panic!("stream ended before agent message");
|
||||
}
|
||||
}
|
||||
}
|
@@ -6,6 +6,7 @@ mod cli_stream;
mod client;
mod compact;
mod compact_resume_fork;
mod cross_session;
mod exec;
mod exec_stream_events;
mod fork_conversation;
@@ -94,21 +94,37 @@ async fn model_selects_expected_tools() {
    let codex_tools = collect_tool_identifiers_for_model("codex-mini-latest").await;
    assert_eq!(
        codex_tools,
        vec!["local_shell".to_string()],
        vec![
            "local_shell".to_string(),
            "list_mcp_resources".to_string(),
            "list_mcp_resource_templates".to_string(),
            "read_mcp_resource".to_string()
        ],
        "codex-mini-latest should expose the local shell tool",
    );

    let o3_tools = collect_tool_identifiers_for_model("o3").await;
    assert_eq!(
        o3_tools,
        vec!["shell".to_string()],
        vec![
            "shell".to_string(),
            "list_mcp_resources".to_string(),
            "list_mcp_resource_templates".to_string(),
            "read_mcp_resource".to_string()
        ],
        "o3 should expose the generic shell tool",
    );

    let gpt5_codex_tools = collect_tool_identifiers_for_model("gpt-5-codex").await;
    assert_eq!(
        gpt5_codex_tools,
        vec!["shell".to_string(), "apply_patch".to_string(),],
        vec![
            "shell".to_string(),
            "list_mcp_resources".to_string(),
            "list_mcp_resource_templates".to_string(),
            "read_mcp_resource".to_string(),
            "apply_patch".to_string()
        ],
        "gpt-5-codex should expose the apply_patch tool",
    );
}
@@ -223,10 +223,28 @@ async fn prompt_tools_are_consistent_across_requests() {
    // our internal implementation is responsible for keeping tools in sync
    // with the OpenAI schema, so we just verify the tool presence here
    let tools_by_model: HashMap<&'static str, Vec<&'static str>> = HashMap::from([
        ("gpt-5", vec!["shell", "update_plan", "view_image"]),
        (
            "gpt-5",
            vec![
                "shell",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "update_plan",
                "view_image",
            ],
        ),
        (
            "gpt-5-codex",
            vec!["shell", "update_plan", "apply_patch", "view_image"],
            vec![
                "shell",
                "list_mcp_resources",
                "list_mcp_resource_templates",
                "read_mcp_resource",
                "update_plan",
                "apply_patch",
                "view_image",
            ],
        ),
    ]);
    let expected_tools_names = tools_by_model
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::ffi::OsStr;
use std::ffi::OsString;
use std::fs;
use std::net::TcpListener;
@@ -35,6 +36,7 @@ use tokio::time::sleep;
use wiremock::matchers::any;

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[serial(mcp_test_value)]
async fn stdio_server_round_trip() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

@@ -86,6 +88,8 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
                        "MCP_TEST_VALUE".to_string(),
                        expected_env_value.to_string(),
                    )])),
                    env_vars: Vec::new(),
                    cwd: None,
                },
                enabled: true,
                startup_timeout_sec: Some(Duration::from_secs(10)),
@@ -106,7 +110,143 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> {
            final_output_json_schema: None,
            cwd: fixture.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            sandbox_policy: SandboxPolicy::ReadOnly,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let begin_event = wait_for_event_with_timeout(
        &fixture.codex,
        |ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
        Duration::from_secs(10),
    )
    .await;

    let EventMsg::McpToolCallBegin(begin) = begin_event else {
        unreachable!("event guard guarantees McpToolCallBegin");
    };
    assert_eq!(begin.invocation.server, server_name);
    assert_eq!(begin.invocation.tool, "echo");

    let end_event = wait_for_event(&fixture.codex, |ev| {
        matches!(ev, EventMsg::McpToolCallEnd(_))
    })
    .await;
    let EventMsg::McpToolCallEnd(end) = end_event else {
        unreachable!("event guard guarantees McpToolCallEnd");
    };

    let result = end
        .result
        .as_ref()
        .expect("rmcp echo tool should return success");
    assert_eq!(result.is_error, Some(false));
    assert!(
        result.content.is_empty(),
        "content should default to an empty array"
    );

    let structured = result
        .structured_content
        .as_ref()
        .expect("structured content");
    let Value::Object(map) = structured else {
        panic!("structured content should be an object: {structured:?}");
    };
    let echo_value = map
        .get("echo")
        .and_then(Value::as_str)
        .expect("echo payload present");
    assert_eq!(echo_value, "ECHOING: ping");
    let env_value = map
        .get("env")
        .and_then(Value::as_str)
        .expect("env snapshot inserted");
    assert_eq!(env_value, expected_env_value);

    wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    server.verify().await;

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[serial(mcp_test_value)]
async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = responses::start_mock_server().await;

    let call_id = "call-1234";
    let server_name = "rmcp_whitelist";
    let tool_name = format!("{server_name}__echo");

    mount_sse_once_match(
        &server,
        any(),
        responses::sse(vec![
            responses::ev_response_created("resp-1"),
            responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
            responses::ev_completed("resp-1"),
        ]),
    )
    .await;
    mount_sse_once_match(
        &server,
        any(),
        responses::sse(vec![
            responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."),
            responses::ev_completed("resp-2"),
        ]),
    )
    .await;

    let expected_env_value = "propagated-env-from-whitelist";
    let _guard = EnvVarGuard::set("MCP_TEST_VALUE", OsStr::new(expected_env_value));
    let rmcp_test_server_bin = CargoBuild::new()
        .package("codex-rmcp-client")
        .bin("test_stdio_server")
        .run()?
        .path()
        .to_string_lossy()
        .into_owned();

    let fixture = test_codex()
        .with_config(move |config| {
            config.features.enable(Feature::RmcpClient);
            config.mcp_servers.insert(
                server_name.to_string(),
                McpServerConfig {
                    transport: McpServerTransportConfig::Stdio {
                        command: rmcp_test_server_bin,
                        args: Vec::new(),
                        env: None,
                        env_vars: vec!["MCP_TEST_VALUE".to_string()],
                        cwd: None,
                    },
                    enabled: true,
                    startup_timeout_sec: Some(Duration::from_secs(10)),
                    tool_timeout_sec: None,
                },
            );
        })
        .build(&server)
        .await?;
    let session_model = fixture.session_configured.model.clone();

    fixture
        .codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "call the rmcp echo tool".into(),
            }],
            final_output_json_schema: None,
            cwd: fixture.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::ReadOnly,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
@@ -235,6 +375,8 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
                    transport: McpServerTransportConfig::StreamableHttp {
                        url: server_url,
                        bearer_token_env_var: None,
                        http_headers: None,
                        env_http_headers: None,
                    },
                    enabled: true,
                    startup_timeout_sec: Some(Duration::from_secs(10)),
@@ -255,7 +397,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
            final_output_json_schema: None,
            cwd: fixture.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            sandbox_policy: SandboxPolicy::ReadOnly,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
@@ -416,6 +558,8 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
                    transport: McpServerTransportConfig::StreamableHttp {
                        url: server_url,
                        bearer_token_env_var: None,
                        http_headers: None,
                        env_http_headers: None,
                    },
                    enabled: true,
                    startup_timeout_sec: Some(Duration::from_secs(10)),
@@ -436,7 +580,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
            final_output_json_schema: None,
            cwd: fixture.cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            sandbox_policy: SandboxPolicy::ReadOnly,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
@@ -17,6 +17,7 @@ use codex_core::ConversationManager;
use codex_core::NewConversation;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::features::Feature;
use codex_core::git_info::get_git_repo_root;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::Event;
@@ -168,8 +169,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        model,
        review_model: None,
        config_profile,
        // This CLI is intended to be headless and has no affordances for asking
        // the user for approval.
        // Default to never ask for approvals in headless mode. Feature flags can override.
        approval_policy: Some(AskForApproval::Never),
        sandbox_mode,
        cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
@@ -192,6 +192,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
    };

    let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides).await?;
    let approve_all_enabled = config.features.enabled(Feature::ApproveAll);

    let otel = codex_core::otel_init::build_provider(&config, env!("CARGO_PKG_VERSION"));

@@ -360,6 +361,34 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
            if matches!(event.msg, EventMsg::Error(_)) {
                error_seen = true;
            }
            // Auto-approve requests when the approve_all feature is enabled.
            if approve_all_enabled {
                match &event.msg {
                    EventMsg::ExecApprovalRequest(_) => {
                        if let Err(e) = conversation
                            .submit(Op::ExecApproval {
                                id: event.id.clone(),
                                decision: codex_core::protocol::ReviewDecision::Approved,
                            })
                            .await
                        {
                            error!("failed to auto-approve exec: {e}");
                        }
                    }
                    EventMsg::ApplyPatchApprovalRequest(_) => {
                        if let Err(e) = conversation
                            .submit(Op::PatchApproval {
                                id: event.id.clone(),
                                decision: codex_core::protocol::ReviewDecision::Approved,
                            })
                            .await
                        {
                            error!("failed to auto-approve patch: {e}");
                        }
                    }
                    _ => {}
                }
            }
            let shutdown: CodexStatus = event_processor.process_event(event);
            match shutdown {
                CodexStatus::Running => continue,
81
codex-rs/exec/tests/suite/approve_all.rs
Normal file
@@ -0,0 +1,81 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::expect_used, clippy::unwrap_used)]

use anyhow::Result;
use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex_exec::test_codex_exec;
use serde_json::Value;
use serde_json::json;

async fn run_exec_with_args(args: &[&str]) -> Result<String> {
    let test = test_codex_exec();

    let call_id = "exec-approve";
    let exec_args = json!({
        "command": [
            if cfg!(windows) { "cmd.exe" } else { "/bin/sh" },
            if cfg!(windows) { "/C" } else { "-lc" },
            "echo approve-all-ok",
        ],
        "timeout_ms": 1500,
        "with_escalated_permissions": true
    });

    let response_streams = vec![
        sse(vec![
            ev_response_created("resp-1"),
            ev_function_call(call_id, "shell", &serde_json::to_string(&exec_args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];

    let server = responses::start_mock_server().await;
    let mock = mount_sse_sequence(&server, response_streams).await;

    test.cmd_with_server(&server).args(args).assert().success();

    let requests = mock.requests();
    assert!(requests.len() >= 2, "expected at least two responses POSTs");
    let item = requests[1].function_call_output(call_id);
    let output_str = item
        .get("output")
        .and_then(Value::as_str)
        .expect("function_call_output.output should be a string");

    Ok(output_str.to_string())
}

/// Setting `features.approve_all=true` should switch to auto-approvals.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn approve_all_auto_accepts_exec() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let output = run_exec_with_args(&[
        "--skip-git-repo-check",
        "-c",
        "features.approve_all=true",
        "train",
    ])
    .await?;
    assert!(
        output.contains("Exit code: 0"),
        "expected Exit code: 0 in output: {output}"
    );
    assert!(
        output.contains("approve-all-ok"),
        "expected command output in response: {output}"
    );

    Ok(())
}
@@ -1,5 +1,6 @@
// Aggregates all former standalone integration tests as modules.
mod apply_patch;
mod approve_all;
mod auth_env;
mod originator;
mod output_schema;
13
codex-rs/feedback/Cargo.toml
Normal file
@@ -0,0 +1,13 @@
[package]
edition.workspace = true
name = "codex-feedback"
version.workspace = true

[dependencies]
anyhow = { workspace = true }
codex-protocol = { workspace = true }
sentry = { version = "0.34" }
tracing-subscriber = { workspace = true }

[dev-dependencies]
pretty_assertions = { workspace = true }
231
codex-rs/feedback/src/lib.rs
Normal file
@@ -0,0 +1,231 @@
use std::collections::VecDeque;
use std::fs;
use std::io::Write;
use std::io::{self};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

use anyhow::Result;
use anyhow::anyhow;
use codex_protocol::ConversationId;
use tracing_subscriber::fmt::writer::MakeWriter;

const DEFAULT_MAX_BYTES: usize = 2 * 1024 * 1024; // 2 MiB
const SENTRY_DSN: &str =
    "https://ae32ed50620d7a7792c1ce5df38b3e3e@o33249.ingest.us.sentry.io/4510195390611458";
const UPLOAD_TIMEOUT_SECS: u64 = 10;

#[derive(Clone)]
pub struct CodexFeedback {
    inner: Arc<FeedbackInner>,
}

impl Default for CodexFeedback {
    fn default() -> Self {
        Self::new()
    }
}

impl CodexFeedback {
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_MAX_BYTES)
    }

    pub(crate) fn with_capacity(max_bytes: usize) -> Self {
        Self {
            inner: Arc::new(FeedbackInner::new(max_bytes)),
        }
    }

    pub fn make_writer(&self) -> FeedbackMakeWriter {
        FeedbackMakeWriter {
            inner: self.inner.clone(),
        }
    }

    pub fn snapshot(&self, session_id: Option<ConversationId>) -> CodexLogSnapshot {
        let bytes = {
            let guard = self.inner.ring.lock().expect("mutex poisoned");
            guard.snapshot_bytes()
        };
        CodexLogSnapshot {
            bytes,
            thread_id: session_id
                .map(|id| id.to_string())
                .unwrap_or("no-active-thread-".to_string() + &ConversationId::new().to_string()),
        }
    }
}

struct FeedbackInner {
    ring: Mutex<RingBuffer>,
}

impl FeedbackInner {
    fn new(max_bytes: usize) -> Self {
        Self {
            ring: Mutex::new(RingBuffer::new(max_bytes)),
        }
    }
}

#[derive(Clone)]
pub struct FeedbackMakeWriter {
    inner: Arc<FeedbackInner>,
}

impl<'a> MakeWriter<'a> for FeedbackMakeWriter {
    type Writer = FeedbackWriter;

    fn make_writer(&'a self) -> Self::Writer {
        FeedbackWriter {
            inner: self.inner.clone(),
        }
    }
}

pub struct FeedbackWriter {
    inner: Arc<FeedbackInner>,
}

impl Write for FeedbackWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut guard = self.inner.ring.lock().map_err(|_| io::ErrorKind::Other)?;
        guard.push_bytes(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

struct RingBuffer {
    max: usize,
    buf: VecDeque<u8>,
}

impl RingBuffer {
    fn new(capacity: usize) -> Self {
        Self {
            max: capacity,
            buf: VecDeque::with_capacity(capacity),
        }
    }

    fn len(&self) -> usize {
        self.buf.len()
    }

    fn push_bytes(&mut self, data: &[u8]) {
        if data.is_empty() {
            return;
        }

        // If the incoming chunk is larger than capacity, keep only the trailing bytes.
        if data.len() >= self.max {
            self.buf.clear();
            let start = data.len() - self.max;
            self.buf.extend(data[start..].iter().copied());
            return;
        }

        // Evict from the front if we would exceed capacity.
        let needed = self.len() + data.len();
        if needed > self.max {
            let to_drop = needed - self.max;
            for _ in 0..to_drop {
                let _ = self.buf.pop_front();
            }
        }

        self.buf.extend(data.iter().copied());
    }

    fn snapshot_bytes(&self) -> Vec<u8> {
        self.buf.iter().copied().collect()
    }
}

pub struct CodexLogSnapshot {
    bytes: Vec<u8>,
    pub thread_id: String,
}

impl CodexLogSnapshot {
    pub(crate) fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }

    pub fn save_to_temp_file(&self) -> io::Result<PathBuf> {
        let dir = std::env::temp_dir();
        let filename = format!("codex-feedback-{}.log", self.thread_id);
        let path = dir.join(filename);
        fs::write(&path, self.as_bytes())?;
        Ok(path)
    }

    pub fn upload_to_sentry(&self) -> Result<()> {
        use std::collections::BTreeMap;
        use std::str::FromStr;
        use std::sync::Arc;

        use sentry::Client;
        use sentry::ClientOptions;
        use sentry::protocol::Attachment;
        use sentry::protocol::Envelope;
        use sentry::protocol::EnvelopeItem;
        use sentry::protocol::Event;
        use sentry::protocol::Level;
        use sentry::transports::DefaultTransportFactory;
        use sentry::types::Dsn;

        let client = Client::from_config(ClientOptions {
            dsn: Some(Dsn::from_str(SENTRY_DSN).map_err(|e| anyhow!("invalid DSN: {}", e))?),
            transport: Some(Arc::new(DefaultTransportFactory {})),
            ..Default::default()
        });

        let tags = BTreeMap::from([(String::from("thread_id"), self.thread_id.to_string())]);

        let event = Event {
            level: Level::Error,
            message: Some("Codex Log Upload ".to_string() + &self.thread_id),
            tags,
            ..Default::default()
        };
        let mut envelope = Envelope::new();
        envelope.add_item(EnvelopeItem::Event(event));
        envelope.add_item(EnvelopeItem::Attachment(Attachment {
            buffer: self.bytes.clone(),
            filename: String::from("codex-logs.log"),
            content_type: Some("text/plain".to_string()),
            ty: None,
        }));

        client.send_envelope(envelope);
        client.flush(Some(Duration::from_secs(UPLOAD_TIMEOUT_SECS)));

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ring_buffer_drops_front_when_full() {
        let fb = CodexFeedback::with_capacity(8);
        {
            let mut w = fb.make_writer().make_writer();
            w.write_all(b"abcdefgh").unwrap();
            w.write_all(b"ij").unwrap();
        }
        let snap = fb.snapshot(None);
        // Capacity 8: after writing 10 bytes, we should keep the last 8.
        pretty_assertions::assert_eq!(std::str::from_utf8(snap.as_bytes()).unwrap(), "cdefghij");
    }
}
189
codex-rs/infty.md
Normal file
@@ -0,0 +1,189 @@
# Codex Infty: Ultra‑Long Task Orchestration

Design a clean, extensible way to run arbitrarily long tasks (hours–days) with bounded model context, autonomous continuation, and robust correctness review. Works for code and non‑code.

Status: Proposed • Scope: New crates using `codex-core` • Compatibility: Non‑breaking

---

## 1) Motivation
- Context windows are limited → we must compact and retrieve.
- Models pause/ask for permission → we must self‑direct.
- No systematic review → we must verify before returning.

## 2) Approach (High‑Level)
Run three coordinated roles as independent `codex-core` sessions. Reuse existing tools (shell, apply_patch, read_file, list_dir, grep_files) for persistence and retrieval. Add one clean, first-class cross-session facility in core for direction/verification: orchestrator-driven, with no model-visible tool. The CLI currently spawns a solver, a director, and three verifiers (`verifier-alpha`, `verifier-beta`, `verifier-gamma`) by default.

- Solver (Model A): executes plan; writes all results to memory/artifacts; never asks humans to continue.
- Director (Model B): answers Solver’s direction questions and re‑plans when needed.
- Verifier (Model C…Cₙ): evaluates completion claims; returns pass/fail with structured feedback.

Inter‑role coordination uses a built‑in CrossSessionHub in core. The orchestrator watches assistant messages and bridges them as user turns to the peer role.

## 3) Architecture
```
          ┌────────────────────────────┐
          │        codex-infty         │
          │   Orchestrator + CLI/Lib   │
          │ - spawns 3 codex sessions  │
          │ - supervises long runs     │
          │ - configures Run/Role      │
          └────────────┬───────────────┘
                       │
  ┌─────────▼─────────┐   ┌─────────▼─────────┐
  │     Solver (A)    │   │    Director (B)   │
  │ codex-core session│   │ codex-core session│
  └─────────┬─────────┘   └─────────┬─────────┘
            │                       │
            └──────────┬────────────┘
                       │
              ┌────────▼────────┐
              │ Verifier(s) (C) │
              │ codex-core sess │
              └────────┬────────┘
                       │
   CrossSessionHub (core, orchestrator‑driven)
        JSONL rollouts, auto‑compaction
```

### Components
- codex-infty (new crate)
  - Spawns/owns three `codex-core` sessions (A/B/C) with role‑specific base instructions.
  - Supervises progress over very long runs.
  - Defines a simple on‑disk Run Store that the models write to using existing tools.
  - Configures sessions with Run/Role metadata (for cross‑session routing).
- codex-core (existing, with one addition)
  - Reuse streaming, tool routing, JSONL rollouts with resume, auto‑compaction, and existing tools:
    - `apply_patch`, `shell`/`exec_command`/`write_stdin`
    - `grep_files`, `read_file`, `list_dir` (enable via model family/experimental tools)
  - New: built‑in `CrossSessionHub` for intra‑process routing (§5). No new model tool is exposed.

## 4) Data Model (Durable) and Filesystem Layout
Persist everything in a Run Store directory; models read/write using existing tools.

- Run Store layout (example under `~/.codex/infty/<run-id>/`):
  - `artifacts/` – blobs and text outputs (models can create via `apply_patch` for text; `shell` for binary moves/copies).
  - `memory/` – JSON/Markdown notes: facts, hypotheses, plans, decisions, claims, evidence, evaluations.
  - `index/` – optional search/index artifacts (built out‑of‑band by orchestrator jobs; models can still use `grep_files`).

Data is append‑only by convention; items link to each other via ids/paths stored in JSON.
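As a concrete illustration, a Solver could persist a completion claim with nothing more than `serde_json` and `std::fs`. This is a minimal sketch: the directory layout is fixed by the design above, but the claim's field names here are illustrative assumptions, not a fixed schema.

```rust
use std::path::Path;

use serde_json::json;

/// Write a hypothetical completion claim into the Run Store.
/// The field names are illustrative only; no claim schema is mandated.
fn write_claim(run_store: &Path) -> anyhow::Result<()> {
    let claim = json!({
        "id": "claim-0001",
        "kind": "completion_claim",
        "summary": "CLI prints the first N Fibonacci numbers",
        "artifacts": ["artifacts/fib.rs"],
        "evidence": ["memory/tests.md"],
    });
    let path = run_store.join("memory").join("claims").join("cli.json");
    std::fs::create_dir_all(path.parent().expect("claims dir"))?;
    std::fs::write(&path, serde_json::to_vec_pretty(&claim)?)?;
    Ok(())
}
```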
## 5) New Core API: CrossSessionHub (no model tool)
Add a core facility that lets the orchestrator bridge assistant messages between sessions by posting user turns.

### 5.1 Hub API
- Registry that maps `{ run_id, role } -> session handle` and `{ session_id } -> session handle`.
- Sessions register on spawn with `run_id` and `role`; unregister on drop.
- Expose async methods for the orchestrator (a bridging sketch follows this list):
  - `post_user_turn(to: RoleOrId, text: String) -> TurnHandle` – inject a `UserTurn` as if typed by a user.
  - `await_first_assistant(turn: &TurnHandle, timeout: Duration) -> AssistantMessage` – wait until the first assistant message for that turn.
  - `stream_events(session_id) -> impl Stream<Item = Event>` – optional subscription for higher‑level orchestration.
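A minimal sketch of the direction bridge built on these methods, mirroring the calls exercised by the `cross_session.rs` test earlier in this diff; the timeout value and the bail-out-on-error handling are illustrative:

```rust
use std::time::Duration;

use codex_core::cross_session::{CrossSessionHub, PostUserTurnRequest, RoleOrId};

/// Forward a Solver question to the Director and post the Director's
/// first reply back to the Solver. Sketch only; errors just propagate.
async fn bridge_direction(
    hub: &CrossSessionHub,
    run_id: &str,
    question: String,
) -> anyhow::Result<()> {
    let to_role = |role: &str| RoleOrId::RunRole {
        run_id: run_id.to_string(),
        role: role.to_string(),
    };
    // Inject the Solver's question into the Director session as a user turn.
    let handle = hub
        .post_user_turn(PostUserTurnRequest {
            target: to_role("director"),
            text: question,
            final_output_json_schema: None,
        })
        .await?;
    // Wait for the Director's first assistant message for that turn.
    let directive = hub
        .await_first_assistant(&handle, Duration::from_secs(60))
        .await?;
    // Relay the directive back to the Solver verbatim.
    hub.post_user_turn(PostUserTurnRequest {
        target: to_role("solver"),
        text: directive.message.message,
        final_output_json_schema: None,
    })
    .await?;
    Ok(())
}
```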
### 5.2 Orchestrator Bridge Logic
- Direction: when the Solver emits an assistant message asking for permission/direction, the orchestrator forwards that assistant text verbatim as a user turn to the Director and waits for the Director’s first assistant reply; it then posts that reply as a user turn to the Solver.
- Verification: when the Solver requests verification, the orchestrator forwards the request to the Verifier(s); structured verdicts (pass/fail/reasons/suggestions) flow back.
- Persistence: each session persists its own events to its rollout; the orchestrator just routes.

## 6) Run Store Facilities
- Memory notes follow JSON schemas per role (plans, claims, evidence).
- Artifacts include code patches, logs, compiled binaries, docs. Use the naming convention `<timestamp>-<summary>.<ext>` (one possible encoding is sketched after this list).
- Orchestrator can create `index/` entries (e.g., embeddings) offline; models still access via standard tools.
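One possible encoding of that naming convention, using Unix seconds from the standard library; the exact timestamp format is an assumption, not part of the design:

```rust
use std::time::{SystemTime, UNIX_EPOCH};

/// Build an artifact filename following `<timestamp>-<summary>.<ext>`.
/// e.g. artifact_name("fib-impl", "rs") -> "1714070000-fib-impl.rs"
fn artifact_name(summary: &str, ext: &str) -> String {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock before Unix epoch")
        .as_secs();
    format!("{secs}-{summary}.{ext}")
}
```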
## 7) Orchestrator Flow
1. Initialize Run Store + metadata (objective, roles, options).
2. Spawn Solver, Director, Verifier sessions via `CrossSessionHub`.
3. Seed objective as Solver user turn; monitor outputs.
4. Relay direction/verification messages automatically between roles.
5. Trigger periodic checkpoints (copy artifacts/memory to dated snapshots).
6. On completion, ensure Verifier returns pass, then emit final deliverable path.
7. Support resume: reload Run Store, respawn sessions with `InitialHistory::Resumed`.

## 8) Context Management
- Conversational context: rely on `codex-core` auto‑compaction.
- Long‑term memory: persist facts/results as files; retrieve with `grep_files`/`read_file`/`list_dir`.
- Run Store snapshots allow cold resume even after orchestrator restart.

## 9) Verification Strategies
- Code: tests, linters, type checks via `shell` under sandbox.
- Text: grader rubrics, citation/contradiction checks.
- Math/research: multi‑verifier consensus, self‑consistency, proof‑sketch validation (a consensus sketch follows this list).
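A hedged sketch of verdict aggregation for multi-verifier consensus: the verdict fields mirror the JSON examples in §14, but this struct is illustrative, not a fixed core type (in practice the schema is supplied per turn via `final_output_json_schema`).

```rust
use serde::Deserialize;

/// Illustrative verdict shape, matching the `{"verdict", "reasons",
/// "suggestions"}` JSON shown in the end-to-end example below.
#[derive(Debug, Deserialize)]
struct VerifierVerdict {
    verdict: String, // "pass" or "fail"
    #[serde(default)]
    reasons: Vec<String>,
    #[serde(default)]
    suggestions: Vec<String>,
}

/// Simple strict-majority consensus; weighted schemes slot in the same way.
fn majority_pass(verdicts: &[VerifierVerdict]) -> bool {
    let passes = verdicts.iter().filter(|v| v.verdict == "pass").count();
    passes * 2 > verdicts.len()
}
```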
## 10) Security & Policy
- All execution stays under `codex-core` sandbox/approval.
- Memory/Artifact tools are pure data I/O (no code execution).
- Inter‑role calls run in isolated sessions.

## 11) MVP (Phased)
1. codex-core
   - Add `CrossSessionHub` with registration and post/await APIs.
   - Add `run_id` and `role` registration on session spawn (optional fields).
   - Tests: two sessions in a run; orchestrator posts user text to Director and bridges reply to Solver.
2. codex-infty
   - Orchestrator lib + CLI: create Run Store directories, spawn A/B/C sessions with `run_id`/`role`, run loop; ship role prompts. Enable `grep_files`/`read_file`/`list_dir`.
3. Verification
   - Use `shell` to run checks/tests when applicable; use Verifier sessions for rubric‑based judgments.

## 12) Finalization & Extensibility
- Finalization workflow (after `verdict == pass`): the orchestrator issues a final `UserTurn` to the Solver instructing:
  - Create a clean `deliverable/` folder under the Run Store.
  - Copy/transform only the necessary end results; remove scratch artifacts.
  - Write a `deliverable/README.md` including: overview, contents manifest with paths and sizes, verification steps (how to run tests), and any limitations.
  - Summarize the work in the final assistant message and return the path to `deliverable/`.

- Extensibility:
  - Pluggable `IndexStrategy` (keyword/embeddings/hybrid) built by the orchestrator (models still query via `grep_files`).
  - Multiple Verifiers with majority/weighted consensus.
  - Future: broadcast/multicast cross‑session calls (e.g., ask three verifiers and aggregate).

## 13) Why This Solves The Three Problems
- Context: conversational compaction + durable memory with retrieval.
- Pauses: assistant questions are bridged to a Director; the orchestrator backstops.
- Review: Solver’s verification request is bridged to Verifier(s) with structured verdicts and remediation.

This keeps `codex-core` focused and leverages its strengths (streaming, tools, compaction, rollouts) while adding a small, clean cross‑session primitive to enable arbitrarily long, autonomous runs across domains.

---

## 14) End‑to‑End Example (Minimal)

Assume a run folder at `~/.codex/infty/run_123/`.

1) User objective → Solver (UserTurn)
- User: "Write a tiny CLI that prints Fibonacci numbers and provide usage docs."

2) Solver starts
- Tool: `update_plan` → steps: parse request; scaffold CLI; implement logic; write docs; verify; finalize deliverable.
- Tool: `grep_files` → searches `artifacts/` and repo for prior art.

3) Solver seeks direction
- Solver’s assistant message: “Confirm plan: binary in ./fib, args: N, output first N Fibonacci numbers; docs in memory/docs.md?”
- Orchestrator posts a UserTurn to Director with that question and sets `final_output_json_schema` to the Director schema (strict).
- Director’s first assistant message:
  ```json
  { "directive": "Proceed. Add tests under memory/tests.md; prefer iterative impl; expose --limit flag.", "rationale": "Keeps stack small; eases verification." }
  ```
- Orchestrator posts that reply as a UserTurn to Solver; Solver continues.

4) Solver implements
- Tool: `apply_patch` → creates `artifacts/fib.rs` and a small Cargo bin, or shell scaffolding.
- Tool: `shell` → `cargo run -- 10` to sanity check (under sandbox).
- Writes `memory/docs.md` and `memory/tests.md`.

5) Solver claims completion
- Writes `memory/claims/cli.json` (per template) referencing artifacts and docs.
- Solver’s assistant message: “Please verify claim in memory/claims/cli.json with artifacts/fib.rs; run cargo test if present.”
- Orchestrator posts a UserTurn to Verifier with `final_output_json_schema` set to the Verifier schema (strict). Verifier runs checks (via `shell`), returns:
  ```json
  { "verdict": "fail", "reasons": ["No tests"], "suggestions": ["Add tests covering N=1,2,10"] }
  ```
- Orchestrator posts that reply as a UserTurn to Solver; Solver iterates (adds tests, fixes issues).

6) Pass and finalize
- Verifier returns `{ "verdict": "pass", … }`.
- Orchestrator issues finalization UserTurn to Solver:
  - "Create deliverable/: include compiled bin or script, usage docs, and tests; write deliverable/README.md with run instructions; remove scratch files."
- Solver:
  - Tool: `shell`/`apply_patch` → builds `deliverable/` with README and artifacts.
  - Assistant message: "Deliverable ready at ~/.codex/infty/run_123/deliverable/."

7) Orchestrator returns the final path to the user.
325
codex-rs/infty2.md
Normal file
@@ -0,0 +1,325 @@
# Infty v2 - Minimal Cross-Session Loop

Goal: collapse the orchestration to three composable primitives while preserving the existing flow.

- spawn: create a role session with base instructions + config
- await: wait for the assistant message that ends the user turn
- forward: inject an assistant message as a user message in another session

The rest of the orchestrator becomes a tiny router that parses the Solver's signal and calls these helpers.

---

## Design Overview

We build a thin, reusable facade over `codex-core`'s cross-session utilities. This facade is role- and run-aware so callers don't need to handle `ConversationId` bookkeeping.

Key types from `codex-core::cross_session` that we lean on:

- `CrossSessionHub` - registers sessions and routes messages across them
- `PostUserTurnRequest` - payload to submit text to a session
- `TurnHandle` - handle for a turn (used to await the assistant)
- `AssistantMessage` - the first assistant message for a turn
- `SessionEventStream` - event stream for activity/idle timeouts

In `codex-infty`, we expose tiny helpers that wrap these primitives in a role-centric API. Role prompts ship with the crate (e.g. [director.md](codex-infty/src/prompts/director.md)).

---

## Minimal API (Facade)

Proposed module: `codex-infty/src/session.rs` (or fold into `orchestrator.rs` if preferred). Names are shown here as free functions; methods on a small struct are also fine.

```rust
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use serde_json::Value;
use codex_core::{ConversationManager, NewConversation};
use codex_core::config::Config;
use codex_core::cross_session::{
    CrossSessionHub, PostUserTurnRequest, RoleOrId, TurnHandle, AssistantMessage,
};
use codex_protocol::ConversationId;

/// Opaque role session reference used by the orchestrator.
#[derive(Clone)]
pub struct RoleSession {
    pub role: String,
    pub conversation_id: ConversationId,
    pub conversation: Arc<codex_core::CodexConversation>,
}

/// 1) Spawn a role session with base instructions applied.
pub async fn spawn(
    hub: Arc<CrossSessionHub>,
    manager: &ConversationManager,
    run_id: &str,
    role: &str,
    mut config: Config,
    rollout_dir: impl Into<std::path::PathBuf>,
    ensure_instructions: impl FnOnce(&str, &mut Config),
) -> Result<RoleSession> {
    config.cwd = rollout_dir.into();
    ensure_instructions(role, &mut config);
    let created: NewConversation = manager
        .new_conversation_with_cross_session(
            config,
            codex_core::CrossSessionSpawnParams {
                hub: Arc::clone(&hub),
                run_id: Some(run_id.to_string()),
                role: Some(role.to_string()),
            },
        )
        .await?;
    Ok(RoleSession {
        role: role.to_string(),
        conversation_id: created.conversation_id,
        conversation: created.conversation,
    })
}

/// 2a) Post a user turn to a role.
pub async fn post(
    hub: &CrossSessionHub,
    run_id: &str,
    role: &str,
    text: impl Into<String>,
    final_output_json_schema: Option<Value>,
) -> Result<TurnHandle, codex_core::cross_session::CrossSessionError> {
    hub.post_user_turn(PostUserTurnRequest {
        target: RoleOrId::RunRole { run_id: run_id.to_string(), role: role.to_string() },
        text: text.into(),
        final_output_json_schema,
    }).await
}

/// 2b) Await the first assistant message for this turn.
pub async fn await_first(
    hub: &CrossSessionHub,
    handle: &TurnHandle,
    timeout: Duration,
) -> Result<AssistantMessage, codex_core::cross_session::CrossSessionError> {
    hub.await_first_assistant(handle, timeout).await
}

/// 2c) Await with idle timeout that resets on activity for this submission id.
/// (Move the existing codex-infty implementation here verbatim.)
```

```rust
pub async fn await_first_idle(
    hub: &CrossSessionHub,
    handle: &TurnHandle,
    idle_timeout: Duration,
) -> Result<AssistantMessage> {
    use anyhow::{anyhow, bail};
    use codex_core::protocol::EventMsg;
    use tokio::time::Instant;
    use tokio_stream::StreamExt as _;

    let mut events = hub.stream_events(handle.conversation_id())?;
    let wait_first = hub.await_first_assistant(handle, idle_timeout);
    tokio::pin!(wait_first);

    let idle = tokio::time::sleep(idle_timeout);
    tokio::pin!(idle);

    let sub_id = handle.submission_id().to_string();

    loop {
        tokio::select! {
            res = &mut wait_first => { return res.map_err(|e| anyhow!(e)); }
            maybe_event = events.next() => {
                let Some(ev) = maybe_event else { bail!(codex_core::cross_session::CrossSessionError::SessionClosed); };
                if ev.event.id == sub_id {
                    if let EventMsg::Error(err) = &ev.event.msg { bail!(anyhow!(err.message.clone())); }
                    idle.as_mut().reset(Instant::now() + idle_timeout);
                }
            }
            _ = &mut idle => { bail!(codex_core::cross_session::CrossSessionError::AwaitTimeout(idle_timeout)); }
        }
    }
}
```

```rust
/// 3) Forward an assistant's content as a user message to another role.
pub async fn forward_assistant(
    hub: &CrossSessionHub,
    run_id: &str,
    target_role: &str,
    assistant: &AssistantMessage,
    timeout: Duration,
    final_output_json_schema: Option<Value>,
) -> Result<AssistantMessage> {
    let handle = post(
        hub,
        run_id,
        target_role,
        assistant.message.message.clone(),
        final_output_json_schema,
    ).await?;
    Ok(await_first(hub, &handle, timeout).await?)
}

/// Convenience: do both post + await in one call.
pub async fn call(
    hub: &CrossSessionHub,
    run_id: &str,
    role: &str,
    text: impl Into<String>,
    timeout: Duration,
    final_output_json_schema: Option<Value>,
) -> Result<AssistantMessage> {
    let handle = post(hub, run_id, role, text, final_output_json_schema).await?;
    Ok(await_first(hub, &handle, timeout).await?)
}
```

Notes:
- `await_first_idle` is the ergonomic default in Infty because it handles streaming with activity-based resets.
- The facade leaves the JSON schema optional and keeps role-addressing consistent with `RunRole { run_id, role }`.

---

## Orchestrator Main Loop Becomes Tiny

Once the three operations exist, the loop reduces to routing:

```rust
// Pseudocode using the facade
let mut solver_ev = hub.stream_events(sessions.solver.conversation_id)?;

if let Some(objective) = options.objective.as_deref() {
    post(&hub, &run_id, &sessions.solver.role, objective, Some(solver_signal_schema())).await?;
}

loop {
    let ev = solver_ev.next().await.ok_or_else(|| anyhow::anyhow!("solver closed"))?;
    if let EventMsg::AgentMessage(agent) = &ev.event.msg {
        if let Some(signal) = parse_solver_signal(&agent.message) {
            match signal {
                SolverSignal::DirectionRequest { prompt: Some(p) } => {
                    let req = serde_json::to_string(&DirectionRequestPayload {
                        kind: "direction_request",
                        prompt: &p,
                        objective: options.objective.as_deref(),
                    })?;
                    let directive = call(&hub, &run_id, &sessions.director.role, req, options.director_timeout, Some(directive_response_schema())).await?;
                    let _ = forward_assistant(&hub, &run_id, &sessions.solver.role, &directive, std::time::Duration::from_secs(5), Some(solver_signal_schema())).await?;
                }
                SolverSignal::VerificationRequest { claim_path: Some(path), notes } => {
                    let req = serde_json::to_string(&VerificationRequestPayload {
                        kind: "verification_request",
                        claim_path: &path,
                        notes: notes.as_deref(),
                        objective: options.objective.as_deref(),
                    })?;
                    let mut verdicts = Vec::new();
                    for v in &sessions.verifiers {
                        let verdict = call(&hub, &run_id, &v.role, &req, options.verifier_timeout, Some(verifier_verdict_schema())).await?;
                        verdicts.push((v.role.clone(), parse_json_struct::<VerifierVerdict>(&verdict.message.message)?));
                    }
                    let summary = aggregate_verdicts(verdicts);
                    let _ = post(&hub, &run_id, &sessions.solver.role, serde_json::to_string(&summary)?, Some(solver_signal_schema())).await?;
                }
                SolverSignal::FinalDelivery { deliverable_path: Some(path), summary } => {
                    let deliverable = resolve_deliverable_path(sessions.store.path(), &path)?;
                    return Ok(RunOutcome { run_id, deliverable_path: deliverable, summary, raw_message: agent.message.clone() });
                }
                _ => {}
            }
        }
    }
}
```

Everything above already exists in `codex-infty` today; the facade simply standardizes the small operations so the loop reads linearly.

---

## Implementation Steps

1) Extract helpers
- Add `session.rs` with `spawn`, `post`, `await_first`, `await_first_idle`, `forward_assistant`, `call`.
- Move the existing `await_first_assistant_idle` body from `orchestrator.rs` to this module (exported).
- Re-export from `lib.rs` if desirable for external callers.

2) Adopt helpers in `orchestrator.rs`
- Replace `post_to_role`, `await_first_assistant`, `relay_assistant_to_role`, and `call_role` with the facade functions.
- Keep signal parsing and run-store logic; delete glue code that becomes redundant.

3) Keep role spawn/resume minimal
- Inline `spawn_role_session` and `resume_role_session` to call `session::spawn(...)` with `prompts::ensure_instructions`.
- Preserve persistence of rollout/config paths via `RunStore`.

4) Preserve JSON schema guarantees
- Pass schemas through `post`/`call`/`forward_assistant` exactly as today:
  - Solver outbound: `solver_signal_schema()`
  - Director outbound: `directive_response_schema()`
  - Verifier outbound: `verifier_verdict_schema()`
  - Finalization: `final_delivery_schema()` for the last probe

5) Progress reporting stays orthogonal
- Where the orchestrator previously called `progress.*`, keep those calls around the facade usage (no change to the trait).

6) Tests and docs
- Unit-test the facade with a tiny harness that posts to a mock/run role and awaits the first assistant.
- Update README examples to use `call` and `forward_assistant` for clarity.

---

## Snippets to Drop In

- Posting user input and awaiting the assistant with idle timeout:

```rust
let handle = session::post(hub, &run_id, &role, user_text, schema).await?;
let assistant = session::await_first_idle(hub, &handle, std::time::Duration::from_secs(120)).await?;
```

- Forwarding an assistant to another role:

```rust
let reply = session::forward_assistant(hub, &run_id, &target_role, &assistant, std::time::Duration::from_secs(60), target_schema).await?;
```

- Spawning a session with base instructions:

```rust
let solver = session::spawn(
    Arc::clone(&hub),
    &conversation_manager,
    &run_id,
    "solver",
    solver_cfg.clone(),
    run_path, // becomes cfg.cwd
    |role, cfg| prompts::ensure_instructions(role, cfg),
).await?;
```

---

## Why This Simplifies Things

- One mental model: "post -> await -> forward" across roles.
- Orchestrator logic is a small, readable router.
- Cross-session reliability remains in one place (the hub).
- Tests become surgical: assert an assistant message is forwarded or a schema is respected.

---

## Backward Compatibility

- All current public behavior stays the same.
- `InftyOrchestrator` public methods keep signatures; they are implemented in terms of the facade.
- No changes to `codex-core` types or wire protocol.

---

## Optional Follow-Ups

- Consider upstreaming `await_first_idle` into `codex-core` so others can reuse it outside Infty.
- Add typed wrappers for JSON payloads (newtypes) to reduce `serde_json::Value` usage at call sites.
- Provide a tiny `SessionRouter` example crate to demonstrate building custom flows with these primitives.
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
    // Spawn the subprocess and connect the client.
    let program = args.remove(0);
    let env = None;
    let client = McpClient::new_stdio_client(program, args, env)
    let client = McpClient::new_stdio_client(program, args, env, &[], None)
        .await
        .with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?;
@@ -13,6 +13,7 @@

use std::collections::HashMap;
use std::ffi::OsString;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
@@ -86,19 +87,26 @@ impl McpClient {
        program: OsString,
        args: Vec<OsString>,
        env: Option<HashMap<String, String>>,
        env_vars: &[String],
        cwd: Option<PathBuf>,
    ) -> std::io::Result<Self> {
        let mut child = Command::new(program)
        let mut command = Command::new(program);
        command
            .args(args)
            .env_clear()
            .envs(create_env_for_mcp_server(env))
            .envs(create_env_for_mcp_server(env, env_vars))
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            // As noted in the `kill_on_drop` documentation, the Tokio runtime makes
            // a "best effort" to reap-after-exit to avoid zombie processes, but it
            // is not a guarantee.
            .kill_on_drop(true)
            .spawn()?;
            .kill_on_drop(true);
        if let Some(cwd) = cwd {
            command.current_dir(cwd);
        }

        let mut child = command.spawn()?;

        let stdin = child
            .stdin
@@ -447,12 +455,16 @@ const DEFAULT_ENV_VARS: &[&str] = &[
/// `config.toml`.
fn create_env_for_mcp_server(
    extra_env: Option<HashMap<String, String>>,
    env_vars: &[String],
) -> HashMap<String, String> {
    DEFAULT_ENV_VARS
        .iter()
        .filter_map(|var| match std::env::var(var) {
            Ok(value) => Some((var.to_string(), value)),
            Err(_) => None,
        .copied()
        .chain(env_vars.iter().map(String::as_str))
        .filter_map(|var| {
            std::env::var(var)
                .ok()
                .map(|value| (var.to_string(), value))
        })
        .chain(extra_env.unwrap_or_default())
        .collect::<HashMap<_, _>>()
@@ -462,14 +474,36 @@ fn create_env_for_mcp_server(
mod tests {
    use super::*;

    fn set_env_var(key: &str, value: &str) {
        unsafe {
            std::env::set_var(key, value);
        }
    }

    fn remove_env_var(key: &str) {
        unsafe {
            std::env::remove_var(key);
        }
    }

    #[test]
    fn test_create_env_for_mcp_server() {
        let env_var = "USER";
        let env_var_existing_value = std::env::var(env_var).unwrap_or_default();
        let env_var_new_value = format!("{env_var_existing_value}-extra");
        let extra_env = HashMap::from([(env_var.to_owned(), env_var_new_value.clone())]);
        let mcp_server_env = create_env_for_mcp_server(Some(extra_env));
        let mcp_server_env = create_env_for_mcp_server(Some(extra_env), &[]);
        assert!(mcp_server_env.contains_key("PATH"));
        assert_eq!(Some(&env_var_new_value), mcp_server_env.get(env_var));
    }

    #[test]
    fn test_create_env_for_mcp_server_includes_extra_whitelisted_vars() {
        let custom_var = "CUSTOM_TEST_VAR";
        let value = "value".to_string();
        set_env_var(custom_var, &value);
        let mcp_server_env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
        assert_eq!(Some(&value), mcp_server_env.get(custom_var));
        remove_env_var(custom_var);
    }
}
@@ -178,6 +178,7 @@ async fn run_codex_tool_session_inner(
                cwd,
                call_id,
                reason: _,
                parsed_cmd,
            }) => {
                handle_exec_approval_request(
                    command,
@@ -188,6 +189,7 @@ async fn run_codex_tool_session_inner(
                    request_id_str.clone(),
                    event.id.clone(),
                    call_id,
                    parsed_cmd,
                )
                .await;
                continue;
@@ -4,6 +4,7 @@ use std::sync::Arc;
 use codex_core::CodexConversation;
 use codex_core::protocol::Op;
 use codex_core::protocol::ReviewDecision;
+use codex_protocol::parse_command::ParsedCommand;
 use mcp_types::ElicitRequest;
 use mcp_types::ElicitRequestParamsRequestedSchema;
 use mcp_types::JSONRPCErrorError;
@@ -35,6 +36,7 @@ pub struct ExecApprovalElicitRequestParams {
     pub codex_call_id: String,
     pub codex_command: Vec<String>,
     pub codex_cwd: PathBuf,
+    pub codex_parsed_cmd: Vec<ParsedCommand>,
 }

 // TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See:
@@ -56,6 +58,7 @@ pub(crate) async fn handle_exec_approval_request(
     tool_call_id: String,
     event_id: String,
     call_id: String,
+    codex_parsed_cmd: Vec<ParsedCommand>,
 ) {
     let escaped_command =
         shlex::try_join(command.iter().map(String::as_str)).unwrap_or_else(|_| command.join(" "));
@@ -77,6 +80,7 @@ pub(crate) async fn handle_exec_approval_request(
         codex_call_id: call_id,
         codex_command: command,
         codex_cwd: cwd,
+        codex_parsed_cmd,
     };
     let params_json = match serde_json::to_value(&params) {
         Ok(value) => value,

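The escaping line above relies on shlex's fallible join; a tiny sketch of that fallback behavior (requires the shlex crate, 1.3+ for try_join):

    // try_join produces a shell-safe rendering of the argument vector; if
    // an argument cannot be represented (e.g. it contains a NUL byte),
    // fall back to a plain space-join.
    fn escape_command(command: &[String]) -> String {
        shlex::try_join(command.iter().map(String::as_str))
            .unwrap_or_else(|_| command.join(" "))
    }
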
@@ -3,6 +3,7 @@ use std::env;
 use std::path::Path;
 use std::path::PathBuf;

+use codex_core::parse_command;
 use codex_core::protocol::FileChange;
 use codex_core::protocol::ReviewDecision;
 use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
@@ -176,6 +177,7 @@ fn create_expected_elicitation_request(
         shlex::try_join(command.iter().map(std::convert::AsRef::as_ref))?,
         workdir.to_string_lossy()
     );
+    let codex_parsed_cmd = parse_command::parse_command(&command);
     Ok(JSONRPCRequest {
         jsonrpc: JSONRPC_VERSION.into(),
         id: elicitation_request_id,
@@ -193,6 +195,7 @@ fn create_expected_elicitation_request(
             codex_command: command,
             codex_cwd: workdir.to_path_buf(),
             codex_call_id: "call1234".to_string(),
+            codex_parsed_cmd,
         })?),
     })
 }

@@ -33,6 +33,7 @@ pub struct OtelEventMetadata {
     conversation_id: ConversationId,
     auth_mode: Option<String>,
     account_id: Option<String>,
+    account_email: Option<String>,
     model: String,
     slug: String,
     log_user_prompts: bool,
@@ -46,11 +47,13 @@ pub struct OtelEventManager {
 }

 impl OtelEventManager {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         conversation_id: ConversationId,
         model: &str,
         slug: &str,
         account_id: Option<String>,
+        account_email: Option<String>,
         auth_mode: Option<AuthMode>,
         log_user_prompts: bool,
         terminal_type: String,
@@ -60,6 +63,7 @@ impl OtelEventManager {
             conversation_id,
             auth_mode: auth_mode.map(|m| m.to_string()),
             account_id,
+            account_email,
             model: model.to_owned(),
             slug: slug.to_owned(),
             log_user_prompts,
@@ -98,6 +102,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -136,6 +141,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -205,6 +211,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -226,6 +233,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -240,6 +248,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -262,6 +271,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -286,6 +296,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -320,6 +331,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -343,6 +355,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -383,7 +396,8 @@ impl OtelEventManager {
             conversation.id = %self.metadata.conversation_id,
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
-            user.account_id= self.metadata.account_id,
+            user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -408,6 +422,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,
@@ -437,6 +452,7 @@ impl OtelEventManager {
             app.version = %self.metadata.app_version,
             auth_mode = self.metadata.auth_mode,
             user.account_id = self.metadata.account_id,
+            user.email = self.metadata.account_email,
             terminal.type = %self.metadata.terminal_type,
             model = %self.metadata.model,
             slug = %self.metadata.slug,

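Every telemetry call site above gains the same optional field; a reduced sketch of the repeated pattern (the field names come from the diff, the event itself is hypothetical, and as_deref is used so that Option<&str> implements tracing's Value trait):

    use tracing::{event, Level};

    fn emit(account_id: Option<String>, account_email: Option<String>) {
        event!(
            Level::INFO,
            user.account_id = account_id.as_deref(),
            // New in this diff: the account email rides along when known.
            user.email = account_email.as_deref(),
            "codex.example"
        );
    }
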
@@ -1,5 +1,6 @@
 use serde::Deserialize;
 use serde::Serialize;
+use std::path::PathBuf;
 use ts_rs::TS;

 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, TS)]
@@ -8,6 +9,11 @@ pub enum ParsedCommand {
     Read {
         cmd: String,
         name: String,
+        /// (Best effort) Path to the file being read by the command. When
+        /// possible, this is an absolute path; when relative, it should be
+        /// resolved against the `cwd` that will be used to run the command
+        /// to derive the absolute path.
+        path: PathBuf,
     },
     ListFiles {
         cmd: String,

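The doc comment's resolution rule is simple enough to state in code; a hypothetical helper illustrating it (not part of the diff):

    use std::path::{Path, PathBuf};

    // A relative ParsedCommand::Read path is interpreted against the cwd
    // the command will run in; absolute paths are taken as-is.
    fn absolute_read_path(path: &Path, cwd: &Path) -> PathBuf {
        if path.is_absolute() {
            path.to_path_buf()
        } else {
            cwd.join(path)
        }
    }

    fn main() {
        let resolved = absolute_read_path(Path::new("src/main.rs"), Path::new("/repo"));
        assert_eq!(resolved, PathBuf::from("/repo/src/main.rs"));
    }
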
@@ -21,6 +21,8 @@ use crate::num_format::format_with_separators;
 use crate::parse_command::ParsedCommand;
 use crate::plan_tool::UpdatePlanArgs;
 use mcp_types::CallToolResult;
+use mcp_types::Resource as McpResource;
+use mcp_types::ResourceTemplate as McpResourceTemplate;
 use mcp_types::Tool as McpTool;
 use serde::Deserialize;
 use serde::Serialize;
@@ -1178,6 +1180,7 @@ pub struct ExecApprovalRequestEvent {
     /// Optional human-readable reason for the approval (e.g. retry without sandbox).
     #[serde(skip_serializing_if = "Option::is_none")]
     pub reason: Option<String>,
+    pub parsed_cmd: Vec<ParsedCommand>,
 }

 #[derive(Debug, Clone, Deserialize, Serialize, TS)]
@@ -1248,6 +1251,10 @@ pub struct GetHistoryEntryResponseEvent {
 pub struct McpListToolsResponseEvent {
     /// Fully qualified tool name -> tool definition.
     pub tools: std::collections::HashMap<String, McpTool>,
+    /// Known resources grouped by server name.
+    pub resources: std::collections::HashMap<String, Vec<McpResource>>,
+    /// Known resource templates grouped by server name.
+    pub resource_templates: std::collections::HashMap<String, Vec<McpResourceTemplate>>,
     /// Authentication status for each configured MCP server.
     pub auth_statuses: std::collections::HashMap<String, McpAuthStatus>,
 }

@@ -57,5 +57,7 @@ urlencoding = { workspace = true }
 webbrowser = { workspace = true }

 [dev-dependencies]
+escargot = { workspace = true }
 pretty_assertions = { workspace = true }
+serial_test = { workspace = true }
 tempfile = { workspace = true }

@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::time::Duration;

 use anyhow::Error;
@@ -6,11 +7,14 @@ use codex_protocol::protocol::McpAuthStatus;
 use reqwest::Client;
 use reqwest::StatusCode;
 use reqwest::Url;
+use reqwest::header::HeaderMap;
 use serde::Deserialize;
 use tracing::debug;

 use crate::OAuthCredentialsStoreMode;
 use crate::oauth::has_oauth_tokens;
+use crate::utils::apply_default_headers;
+use crate::utils::build_default_headers;

 const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
 const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
@@ -21,6 +25,8 @@ pub async fn determine_streamable_http_auth_status(
     server_name: &str,
     url: &str,
     bearer_token_env_var: Option<&str>,
+    http_headers: Option<HashMap<String, String>>,
+    env_http_headers: Option<HashMap<String, String>>,
     store_mode: OAuthCredentialsStoreMode,
 ) -> Result<McpAuthStatus> {
     if bearer_token_env_var.is_some() {
@@ -31,7 +37,9 @@ pub async fn determine_streamable_http_auth_status(
         return Ok(McpAuthStatus::OAuth);
     }

-    match supports_oauth_login(url).await {
+    let default_headers = build_default_headers(http_headers, env_http_headers)?;
+
+    match supports_oauth_login_with_headers(url, &default_headers).await {
         Ok(true) => Ok(McpAuthStatus::NotLoggedIn),
         Ok(false) => Ok(McpAuthStatus::Unsupported),
         Err(error) => {
@@ -45,8 +53,13 @@ pub async fn determine_streamable_http_auth_status(

 /// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login.
 pub async fn supports_oauth_login(url: &str) -> Result<bool> {
+    supports_oauth_login_with_headers(url, &HeaderMap::new()).await
+}
+
+async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result<bool> {
     let base_url = Url::parse(url)?;
-    let client = Client::builder().timeout(DISCOVERY_TIMEOUT).build()?;
+    let builder = Client::builder().timeout(DISCOVERY_TIMEOUT);
+    let client = apply_default_headers(builder, default_headers).build()?;

     let mut last_error: Option<Error> = None;
     for candidate_path in discovery_paths(base_url.path()) {

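Existing callers keep the one-argument probe; a brief hedged usage sketch, assuming the function is in scope (the URL is a placeholder, and each discovery request uses the 5-second client timeout defined above):

    // Returns Ok(true) when one of the server's well-known discovery
    // endpoints indicates OAuth support.
    async fn print_oauth_support(url: &str) -> anyhow::Result<()> {
        if supports_oauth_login(url).await? {
            println!("{url} advertises OAuth login");
        }
        Ok(())
    }
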
@@ -8,8 +8,17 @@ use rmcp::handler::server::ServerHandler;
 use rmcp::model::CallToolRequestParam;
 use rmcp::model::CallToolResult;
 use rmcp::model::JsonObject;
+use rmcp::model::ListResourceTemplatesResult;
+use rmcp::model::ListResourcesResult;
 use rmcp::model::ListToolsResult;
+use rmcp::model::PaginatedRequestParam;
+use rmcp::model::RawResource;
+use rmcp::model::RawResourceTemplate;
+use rmcp::model::ReadResourceRequestParam;
+use rmcp::model::ReadResourceResult;
+use rmcp::model::Resource;
+use rmcp::model::ResourceContents;
+use rmcp::model::ResourceTemplate;
 use rmcp::model::ServerCapabilities;
 use rmcp::model::ServerInfo;
 use rmcp::model::Tool;
@@ -20,15 +29,24 @@ use tokio::task;
 #[derive(Clone)]
 struct TestToolServer {
     tools: Arc<Vec<Tool>>,
+    resources: Arc<Vec<Resource>>,
+    resource_templates: Arc<Vec<ResourceTemplate>>,
 }

+const MEMO_URI: &str = "memo://codex/example-note";
+const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
+
 pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
     (tokio::io::stdin(), tokio::io::stdout())
 }
 impl TestToolServer {
     fn new() -> Self {
         let tools = vec![Self::echo_tool()];
+        let resources = vec![Self::memo_resource()];
+        let resource_templates = vec![Self::memo_template()];
         Self {
             tools: Arc::new(tools),
+            resources: Arc::new(resources),
+            resource_templates: Arc::new(resource_templates),
         }
     }
@@ -51,6 +69,36 @@ impl TestToolServer {
             Arc::new(schema),
         )
     }

+    fn memo_resource() -> Resource {
+        let raw = RawResource {
+            uri: MEMO_URI.to_string(),
+            name: "example-note".to_string(),
+            title: Some("Example Note".to_string()),
+            description: Some("A sample MCP resource exposed for integration tests.".to_string()),
+            mime_type: Some("text/plain".to_string()),
+            size: None,
+            icons: None,
+        };
+        Resource::new(raw, None)
+    }
+
+    fn memo_template() -> ResourceTemplate {
+        let raw = RawResourceTemplate {
+            uri_template: "memo://codex/{slug}".to_string(),
+            name: "codex-memo".to_string(),
+            title: Some("Codex Memo".to_string()),
+            description: Some(
+                "Template for memo://codex/{slug} resources used in tests.".to_string(),
+            ),
+            mime_type: Some("text/plain".to_string()),
+        };
+        ResourceTemplate::new(raw, None)
+    }
+
+    fn memo_text() -> &'static str {
+        MEMO_CONTENT
+    }
 }

 #[derive(Deserialize)]
@@ -66,6 +114,7 @@ impl ServerHandler for TestToolServer {
         capabilities: ServerCapabilities::builder()
             .enable_tools()
             .enable_tool_list_changed()
+            .enable_resources()
             .build(),
         ..ServerInfo::default()
     }
@@ -85,6 +134,53 @@ impl ServerHandler for TestToolServer {
         }
     }

+    fn list_resources(
+        &self,
+        _request: Option<PaginatedRequestParam>,
+        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
+    ) -> impl std::future::Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
+        let resources = self.resources.clone();
+        async move {
+            Ok(ListResourcesResult {
+                resources: (*resources).clone(),
+                next_cursor: None,
+            })
+        }
+    }
+
+    async fn list_resource_templates(
+        &self,
+        _request: Option<PaginatedRequestParam>,
+        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
+    ) -> Result<ListResourceTemplatesResult, McpError> {
+        Ok(ListResourceTemplatesResult {
+            resource_templates: (*self.resource_templates).clone(),
+            next_cursor: None,
+        })
+    }
+
+    async fn read_resource(
+        &self,
+        ReadResourceRequestParam { uri }: ReadResourceRequestParam,
+        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
+    ) -> Result<ReadResourceResult, McpError> {
+        if uri == MEMO_URI {
+            Ok(ReadResourceResult {
+                contents: vec![ResourceContents::TextResourceContents {
+                    uri,
+                    mime_type: Some("text/plain".to_string()),
+                    text: Self::memo_text().to_string(),
+                    meta: None,
+                }],
+            })
+        } else {
+            Err(McpError::resource_not_found(
+                "resource_not_found",
+                Some(json!({ "uri": uri })),
+            ))
+        }
+    }
+
     async fn call_tool(
         &self,
         request: CallToolRequestParam,

@@ -18,8 +18,17 @@ use rmcp::handler::server::ServerHandler;
 use rmcp::model::CallToolRequestParam;
 use rmcp::model::CallToolResult;
 use rmcp::model::JsonObject;
+use rmcp::model::ListResourceTemplatesResult;
+use rmcp::model::ListResourcesResult;
 use rmcp::model::ListToolsResult;
+use rmcp::model::PaginatedRequestParam;
+use rmcp::model::RawResource;
+use rmcp::model::RawResourceTemplate;
+use rmcp::model::ReadResourceRequestParam;
+use rmcp::model::ReadResourceResult;
+use rmcp::model::Resource;
+use rmcp::model::ResourceContents;
+use rmcp::model::ResourceTemplate;
 use rmcp::model::ServerCapabilities;
 use rmcp::model::ServerInfo;
 use rmcp::model::Tool;
@@ -33,13 +42,22 @@ use tokio::task;
 #[derive(Clone)]
 struct TestToolServer {
     tools: Arc<Vec<Tool>>,
+    resources: Arc<Vec<Resource>>,
+    resource_templates: Arc<Vec<ResourceTemplate>>,
 }

+const MEMO_URI: &str = "memo://codex/example-note";
+const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
+
 impl TestToolServer {
     fn new() -> Self {
         let tools = vec![Self::echo_tool()];
+        let resources = vec![Self::memo_resource()];
+        let resource_templates = vec![Self::memo_template()];
         Self {
             tools: Arc::new(tools),
+            resources: Arc::new(resources),
+            resource_templates: Arc::new(resource_templates),
         }
     }
@@ -62,6 +80,36 @@ impl TestToolServer {
             Arc::new(schema),
         )
     }

+    fn memo_resource() -> Resource {
+        let raw = RawResource {
+            uri: MEMO_URI.to_string(),
+            name: "example-note".to_string(),
+            title: Some("Example Note".to_string()),
+            description: Some("A sample MCP resource exposed for integration tests.".to_string()),
+            mime_type: Some("text/plain".to_string()),
+            size: None,
+            icons: None,
+        };
+        Resource::new(raw, None)
+    }
+
+    fn memo_template() -> ResourceTemplate {
+        let raw = RawResourceTemplate {
+            uri_template: "memo://codex/{slug}".to_string(),
+            name: "codex-memo".to_string(),
+            title: Some("Codex Memo".to_string()),
+            description: Some(
+                "Template for memo://codex/{slug} resources used in tests.".to_string(),
+            ),
+            mime_type: Some("text/plain".to_string()),
+        };
+        ResourceTemplate::new(raw, None)
+    }
+
+    fn memo_text() -> &'static str {
+        MEMO_CONTENT
+    }
 }

 #[derive(Deserialize)]
@@ -77,6 +125,7 @@ impl ServerHandler for TestToolServer {
         capabilities: ServerCapabilities::builder()
             .enable_tools()
             .enable_tool_list_changed()
+            .enable_resources()
             .build(),
         ..ServerInfo::default()
     }
@@ -96,6 +145,53 @@ impl ServerHandler for TestToolServer {
         }
     }

+    fn list_resources(
+        &self,
+        _request: Option<PaginatedRequestParam>,
+        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
+    ) -> impl std::future::Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
+        let resources = self.resources.clone();
+        async move {
+            Ok(ListResourcesResult {
+                resources: (*resources).clone(),
+                next_cursor: None,
+            })
+        }
+    }
+
+    async fn list_resource_templates(
+        &self,
+        _request: Option<PaginatedRequestParam>,
+        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
+    ) -> Result<ListResourceTemplatesResult, McpError> {
+        Ok(ListResourceTemplatesResult {
+            resource_templates: (*self.resource_templates).clone(),
+            next_cursor: None,
+        })
+    }
+
+    async fn read_resource(
+        &self,
+        ReadResourceRequestParam { uri }: ReadResourceRequestParam,
+        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
+    ) -> Result<ReadResourceResult, McpError> {
+        if uri == MEMO_URI {
+            Ok(ReadResourceResult {
+                contents: vec![ResourceContents::TextResourceContents {
+                    uri,
+                    mime_type: Some("text/plain".to_string()),
+                    text: Self::memo_text().to_string(),
+                    meta: None,
+                }],
+            })
+        } else {
+            Err(McpError::resource_not_found(
+                "resource_not_found",
+                Some(json!({ "uri": uri })),
+            ))
+        }
+    }
+
     async fn call_tool(
         &self,
         request: CallToolRequestParam,

@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::string::String;
 use std::sync::Arc;
 use std::time::Duration;
@@ -5,6 +6,7 @@ use std::time::Duration;
 use anyhow::Context;
 use anyhow::Result;
 use anyhow::anyhow;
+use reqwest::ClientBuilder;
 use rmcp::transport::auth::OAuthState;
 use tiny_http::Response;
 use tiny_http::Server;
@@ -16,6 +18,8 @@ use crate::OAuthCredentialsStoreMode;
 use crate::StoredOAuthTokens;
 use crate::WrappedOAuthTokenResponse;
 use crate::save_oauth_tokens;
+use crate::utils::apply_default_headers;
+use crate::utils::build_default_headers;

 struct CallbackServerGuard {
     server: Arc<Server>,
@@ -31,6 +35,8 @@ pub async fn perform_oauth_login(
     server_name: &str,
     server_url: &str,
     store_mode: OAuthCredentialsStoreMode,
+    http_headers: Option<HashMap<String, String>>,
+    env_http_headers: Option<HashMap<String, String>>,
 ) -> Result<()> {
     let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
     let guard = CallbackServerGuard {
@@ -51,7 +57,10 @@ pub async fn perform_oauth_login(
     let (tx, rx) = oneshot::channel();
     spawn_callback_server(server, tx);

-    let mut oauth_state = OAuthState::new(server_url, None).await?;
+    let default_headers = build_default_headers(http_headers, env_http_headers)?;
+    let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
+
+    let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
     oauth_state
         .start_authorization(&[], &redirect_uri, Some("Codex"))
         .await?;

@@ -1,6 +1,7 @@
 use std::collections::HashMap;
 use std::ffi::OsString;
 use std::io;
+use std::path::PathBuf;
 use std::process::Stdio;
 use std::sync::Arc;
 use std::time::Duration;
@@ -12,11 +13,19 @@ use mcp_types::CallToolRequestParams;
 use mcp_types::CallToolResult;
 use mcp_types::InitializeRequestParams;
 use mcp_types::InitializeResult;
+use mcp_types::ListResourceTemplatesRequestParams;
+use mcp_types::ListResourceTemplatesResult;
+use mcp_types::ListResourcesRequestParams;
+use mcp_types::ListResourcesResult;
 use mcp_types::ListToolsRequestParams;
 use mcp_types::ListToolsResult;
+use mcp_types::ReadResourceRequestParams;
+use mcp_types::ReadResourceResult;
+use reqwest::header::HeaderMap;
 use rmcp::model::CallToolRequestParam;
 use rmcp::model::InitializeRequestParam;
+use rmcp::model::PaginatedRequestParam;
+use rmcp::model::ReadResourceRequestParam;
 use rmcp::service::RoleClient;
 use rmcp::service::RunningService;
 use rmcp::service::{self};
@@ -38,6 +47,8 @@ use crate::logging_client_handler::LoggingClientHandler;
 use crate::oauth::OAuthCredentialsStoreMode;
 use crate::oauth::OAuthPersistor;
 use crate::oauth::StoredOAuthTokens;
+use crate::utils::apply_default_headers;
+use crate::utils::build_default_headers;
 use crate::utils::convert_call_tool_result;
 use crate::utils::convert_to_mcp;
 use crate::utils::convert_to_rmcp;
@@ -76,6 +87,8 @@ impl RmcpClient {
         program: OsString,
         args: Vec<OsString>,
         env: Option<HashMap<String, String>>,
+        env_vars: &[String],
+        cwd: Option<PathBuf>,
     ) -> io::Result<Self> {
         let program_name = program.to_string_lossy().into_owned();
         let mut command = Command::new(&program);
@@ -84,8 +97,11 @@ impl RmcpClient {
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())
             .env_clear()
-            .envs(create_env_for_mcp_server(env))
+            .envs(create_env_for_mcp_server(env, env_vars))
             .args(&args);
+        if let Some(cwd) = cwd {
+            command.current_dir(cwd);
+        }

         let (transport, stderr) = TokioChildProcess::builder(command)
             .stderr(Stdio::piped())
@@ -116,12 +132,17 @@ impl RmcpClient {
         })
     }

+    #[allow(clippy::too_many_arguments)]
     pub async fn new_streamable_http_client(
         server_name: &str,
         url: &str,
         bearer_token: Option<String>,
+        http_headers: Option<HashMap<String, String>>,
+        env_http_headers: Option<HashMap<String, String>>,
         store_mode: OAuthCredentialsStoreMode,
     ) -> Result<Self> {
+        let default_headers = build_default_headers(http_headers, env_http_headers)?;
+
         let initial_oauth_tokens = match bearer_token {
             Some(_) => None,
             None => match load_oauth_tokens(server_name, url, store_mode) {
@@ -132,21 +153,30 @@ impl RmcpClient {
                 }
             },
         };

         let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
-            let (transport, oauth_persistor) =
-                create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode)
-                    .await?;
+            let (transport, oauth_persistor) = create_oauth_transport_and_runtime(
+                server_name,
+                url,
+                initial_tokens,
+                store_mode,
+                default_headers.clone(),
+            )
+            .await?;
             PendingTransport::StreamableHttpWithOAuth {
                 transport,
                 oauth_persistor,
             }
         } else {
             let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
-            if let Some(bearer_token) = bearer_token {
+            if let Some(bearer_token) = bearer_token.clone() {
                 http_config = http_config.auth_header(bearer_token);
             }

-            let transport = StreamableHttpClientTransport::from_config(http_config);
+            let http_client =
+                apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
+
+            let transport = StreamableHttpClientTransport::with_client(http_client, http_config);
             PendingTransport::StreamableHttp { transport }
         };
         Ok(Self {
@@ -241,6 +271,54 @@ impl RmcpClient {
         Ok(converted)
     }

+    pub async fn list_resources(
+        &self,
+        params: Option<ListResourcesRequestParams>,
+        timeout: Option<Duration>,
+    ) -> Result<ListResourcesResult> {
+        let service = self.service().await?;
+        let rmcp_params = params
+            .map(convert_to_rmcp::<_, PaginatedRequestParam>)
+            .transpose()?;
+
+        let fut = service.list_resources(rmcp_params);
+        let result = run_with_timeout(fut, timeout, "resources/list").await?;
+        let converted = convert_to_mcp(result)?;
+        self.persist_oauth_tokens().await;
+        Ok(converted)
+    }
+
+    pub async fn list_resource_templates(
+        &self,
+        params: Option<ListResourceTemplatesRequestParams>,
+        timeout: Option<Duration>,
+    ) -> Result<ListResourceTemplatesResult> {
+        let service = self.service().await?;
+        let rmcp_params = params
+            .map(convert_to_rmcp::<_, PaginatedRequestParam>)
+            .transpose()?;
+
+        let fut = service.list_resource_templates(rmcp_params);
+        let result = run_with_timeout(fut, timeout, "resources/templates/list").await?;
+        let converted = convert_to_mcp(result)?;
+        self.persist_oauth_tokens().await;
+        Ok(converted)
+    }
+
+    pub async fn read_resource(
+        &self,
+        params: ReadResourceRequestParams,
+        timeout: Option<Duration>,
+    ) -> Result<ReadResourceResult> {
+        let service = self.service().await?;
+        let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?;
+        let fut = service.read_resource(rmcp_params);
+        let result = run_with_timeout(fut, timeout, "resources/read").await?;
+        let converted = convert_to_mcp(result)?;
+        self.persist_oauth_tokens().await;
+        Ok(converted)
+    }
+
     pub async fn call_tool(
         &self,
         name: String,
@@ -276,6 +354,8 @@ impl RmcpClient {
         }
     }

+    /// This should be called after every tool call so that if a given tool call triggered
+    /// a refresh of the OAuth tokens, they are persisted.
     async fn persist_oauth_tokens(&self) {
         if let Some(runtime) = self.oauth_persistor().await
             && let Err(error) = runtime.persist_if_needed().await
@@ -290,11 +370,13 @@ async fn create_oauth_transport_and_runtime(
     url: &str,
     initial_tokens: StoredOAuthTokens,
     credentials_store: OAuthCredentialsStoreMode,
+    default_headers: HeaderMap,
 ) -> Result<(
     StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
     OAuthPersistor,
 )> {
-    let http_client = reqwest::Client::builder().build()?;
+    let http_client =
+        apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
     let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;

     oauth_state

@@ -6,6 +6,10 @@ use anyhow::Context;
 use anyhow::Result;
 use anyhow::anyhow;
 use mcp_types::CallToolResult;
+use reqwest::ClientBuilder;
+use reqwest::header::HeaderMap;
+use reqwest::header::HeaderName;
+use reqwest::header::HeaderValue;
 use rmcp::model::CallToolResult as RmcpCallToolResult;
 use rmcp::service::ServiceError;
 use serde_json::Value;
@@ -70,14 +74,86 @@ where

 pub(crate) fn create_env_for_mcp_server(
     extra_env: Option<HashMap<String, String>>,
+    env_vars: &[String],
 ) -> HashMap<String, String> {
     DEFAULT_ENV_VARS
         .iter()
+        .copied()
+        .chain(env_vars.iter().map(String::as_str))
         .filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value)))
         .chain(extra_env.unwrap_or_default())
         .collect()
 }

+pub(crate) fn build_default_headers(
+    http_headers: Option<HashMap<String, String>>,
+    env_http_headers: Option<HashMap<String, String>>,
+) -> Result<HeaderMap> {
+    let mut headers = HeaderMap::new();
+
+    if let Some(static_headers) = http_headers {
+        for (name, value) in static_headers {
+            let header_name = match HeaderName::from_bytes(name.as_bytes()) {
+                Ok(name) => name,
+                Err(err) => {
+                    tracing::warn!("invalid HTTP header name `{name}`: {err}");
+                    continue;
+                }
+            };
+            let header_value = match HeaderValue::from_str(value.as_str()) {
+                Ok(value) => value,
+                Err(err) => {
+                    tracing::warn!("invalid HTTP header value for `{name}`: {err}");
+                    continue;
+                }
+            };
+            headers.insert(header_name, header_value);
+        }
+    }
+
+    if let Some(env_headers) = env_http_headers {
+        for (name, env_var) in env_headers {
+            if let Ok(value) = env::var(&env_var) {
+                if value.trim().is_empty() {
+                    continue;
+                }
+
+                let header_name = match HeaderName::from_bytes(name.as_bytes()) {
+                    Ok(name) => name,
+                    Err(err) => {
+                        tracing::warn!("invalid HTTP header name `{name}`: {err}");
+                        continue;
+                    }
+                };
+
+                let header_value = match HeaderValue::from_str(value.as_str()) {
+                    Ok(value) => value,
+                    Err(err) => {
+                        tracing::warn!(
+                            "invalid HTTP header value read from {env_var} for `{name}`: {err}"
+                        );
+                        continue;
+                    }
+                };
+                headers.insert(header_name, header_value);
+            }
+        }
+    }
+
+    Ok(headers)
+}
+
+pub(crate) fn apply_default_headers(
+    builder: ClientBuilder,
+    default_headers: &HeaderMap,
+) -> ClientBuilder {
+    if default_headers.is_empty() {
+        builder
+    } else {
+        builder.default_headers(default_headers.clone())
+    }
+}
+
 #[cfg(unix)]
 pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
     "HOME",
@@ -112,13 +188,59 @@ mod tests {
     use rmcp::model::CallToolResult as RmcpCallToolResult;
     use serde_json::json;
+    use serial_test::serial;
+    use std::ffi::OsString;

+    struct EnvVarGuard {
+        key: String,
+        original: Option<OsString>,
+    }
+
+    impl EnvVarGuard {
+        fn set(key: &str, value: &str) -> Self {
+            let original = std::env::var_os(key);
+            unsafe {
+                std::env::set_var(key, value);
+            }
+            Self {
+                key: key.to_string(),
+                original,
+            }
+        }
+    }
+
+    impl Drop for EnvVarGuard {
+        fn drop(&mut self) {
+            if let Some(value) = &self.original {
+                unsafe {
+                    std::env::set_var(&self.key, value);
+                }
+            } else {
+                unsafe {
+                    std::env::remove_var(&self.key);
+                }
+            }
+        }
+    }
+
     #[tokio::test]
     async fn create_env_honors_overrides() {
         let value = "custom".to_string();
-        let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])));
+        let env =
+            create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]);
         assert_eq!(env.get("TZ"), Some(&value));
     }

+    #[test]
+    #[serial(extra_rmcp_env)]
+    fn create_env_includes_additional_whitelisted_variables() {
+        let custom_var = "EXTRA_RMCP_ENV";
+        let value = "from-env";
+        let _guard = EnvVarGuard::set(custom_var, value);
+        let env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
+        assert_eq!(env.get(custom_var), Some(&value.to_string()));
+    }
+
     #[test]
     fn convert_call_tool_result_defaults_missing_content() -> Result<()> {
         let structured_content = json!({ "key": "value" });

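The header-merging helpers above are new in this diff; a minimal usage sketch, assuming they are in scope (the header names and the environment-variable name are hypothetical):

    use std::collections::HashMap;

    // Static headers are applied verbatim; env-mapped headers are read
    // from the named variables at call time and skipped when unset or
    // blank. The resulting HeaderMap is installed as reqwest defaults.
    fn example_client() -> anyhow::Result<reqwest::Client> {
        let static_headers =
            HashMap::from([("X-Title".to_string(), "Codex".to_string())]);
        // "MY_TOKEN_ENV" is a placeholder variable name for illustration.
        let env_headers =
            HashMap::from([("Authorization".to_string(), "MY_TOKEN_ENV".to_string())]);
        let headers = build_default_headers(Some(static_headers), Some(env_headers))?;
        Ok(apply_default_headers(reqwest::Client::builder(), &headers).build()?)
    }
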
codex-rs/rmcp-client/tests/resources.rs (new file, 124 lines)
@@ -0,0 +1,124 @@
use std::ffi::OsString;
use std::path::PathBuf;
use std::time::Duration;

use codex_rmcp_client::RmcpClient;
use escargot::CargoBuild;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::InitializeRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResultContents;
use mcp_types::Resource;
use mcp_types::ResourceTemplate;
use mcp_types::TextResourceContents;
use serde_json::json;

const RESOURCE_URI: &str = "memo://codex/example-note";

fn stdio_server_bin() -> anyhow::Result<PathBuf> {
    let build = CargoBuild::new()
        .package("codex-rmcp-client")
        .bin("test_stdio_server")
        .run()?;
    Ok(build.path().to_path_buf())
}

fn init_params() -> InitializeRequestParams {
    InitializeRequestParams {
        capabilities: ClientCapabilities {
            experimental: None,
            roots: None,
            sampling: None,
            elicitation: Some(json!({})),
        },
        client_info: Implementation {
            name: "codex-test".into(),
            version: "0.0.0-test".into(),
            title: Some("Codex rmcp resource test".into()),
            user_agent: None,
        },
        protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_string(),
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> {
    let client = RmcpClient::new_stdio_client(
        stdio_server_bin()?.into(),
        Vec::<OsString>::new(),
        None,
        &[],
        None,
    )
    .await?;

    client
        .initialize(init_params(), Some(Duration::from_secs(5)))
        .await?;

    let list = client
        .list_resources(None, Some(Duration::from_secs(5)))
        .await?;
    let memo = list
        .resources
        .iter()
        .find(|resource| resource.uri == RESOURCE_URI)
        .expect("memo resource present");
    assert_eq!(
        memo,
        &Resource {
            annotations: None,
            description: Some("A sample MCP resource exposed for integration tests.".to_string()),
            mime_type: Some("text/plain".to_string()),
            name: "example-note".to_string(),
            size: None,
            title: Some("Example Note".to_string()),
            uri: RESOURCE_URI.to_string(),
        }
    );
    let templates = client
        .list_resource_templates(None, Some(Duration::from_secs(5)))
        .await?;
    assert_eq!(
        templates,
        ListResourceTemplatesResult {
            next_cursor: None,
            resource_templates: vec![ResourceTemplate {
                annotations: None,
                description: Some(
                    "Template for memo://codex/{slug} resources used in tests.".to_string()
                ),
                mime_type: Some("text/plain".to_string()),
                name: "codex-memo".to_string(),
                title: Some("Codex Memo".to_string()),
                uri_template: "memo://codex/{slug}".to_string(),
            }],
        }
    );

    let read = client
        .read_resource(
            ReadResourceRequestParams {
                uri: RESOURCE_URI.to_string(),
            },
            Some(Duration::from_secs(5)),
        )
        .await?;
    let ReadResourceResultContents::TextResourceContents(text) =
        read.contents.first().expect("resource contents present")
    else {
        panic!("expected text resource");
    };
    assert_eq!(
        text,
        &TextResourceContents {
            text: "This is a sample MCP resource served by the rmcp test server.".to_string(),
            uri: RESOURCE_URI.to_string(),
            mime_type: Some("text/plain".to_string()),
        }
    );

    Ok(())
}
Some files were not shown because too many files have changed in this diff.