mirror of
https://github.com/openai/codex.git
synced 2026-02-02 23:13:37 +00:00
Compare commits
12 Commits
concurrent
...
codex/impl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d465d71955 | ||
|
|
3d1cfe31a2 | ||
|
|
d6e934f7cd | ||
|
|
0b83f2965c | ||
|
|
d4dc3b11bc | ||
|
|
bcbe02ff1d | ||
|
|
51257e2fd0 | ||
|
|
0ece374c58 | ||
|
|
f532554924 | ||
|
|
f9609cc9bf | ||
|
|
781798b4ed | ||
|
|
5bafe0dc59 |
@@ -21,7 +21,7 @@
|
||||
"settings": {
|
||||
"terminal.integrated.defaultProfile.linux": "bash"
|
||||
},
|
||||
"extensions": ["rust-lang.rust-analyzer", "tamasfe.even-better-toml"]
|
||||
"extensions": ["rust-lang.rust-analyzer"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
18
.vscode/launch.json
vendored
18
.vscode/launch.json
vendored
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Cargo launch",
|
||||
"cargo": {
|
||||
"cwd": "${workspaceFolder}/codex-rs",
|
||||
"args": [
|
||||
"build",
|
||||
"--bin=codex-tui"
|
||||
]
|
||||
},
|
||||
"args": []
|
||||
}
|
||||
]
|
||||
}
|
||||
10
.vscode/settings.json
vendored
10
.vscode/settings.json
vendored
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"rust-analyzer.checkOnSave": true,
|
||||
"rust-analyzer.check.command": "clippy",
|
||||
"rust-analyzer.check.extraArgs": ["--all-features", "--tests"],
|
||||
"rust-analyzer.rustfmt.extraArgs": ["--config", "imports_granularity=Item"],
|
||||
"[rust]": {
|
||||
"editor.defaultFormatter": "rust-lang.rust-analyzer",
|
||||
"editor.formatOnSave": true,
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,3 @@
|
||||
In the codex-rs folder where the rust code lives:
|
||||
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR`. You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
|
||||
Before creating a pull request with changes to `codex-rs`, run `just fmt` (in the `codex-rs` directory) to format the code and `just fix` (in the `codex-rs` directory) to fix any linter issues, then ensure the test suite passes by running `cargo test --all-features` in the `codex-rs` directory.
|
||||
|
||||
When making individual changes prefer running tests on individual files or projects first.
|
||||
|
||||
@@ -370,26 +370,11 @@ export function isSafeCommand(
|
||||
reason: "View file with line numbers",
|
||||
group: "Reading files",
|
||||
};
|
||||
case "rg": {
|
||||
// Certain ripgrep options execute external commands or invoke other
|
||||
// processes, so we must reject them.
|
||||
const isUnsafe = command.some(
|
||||
(arg: string) =>
|
||||
UNSAFE_OPTIONS_FOR_RIPGREP_WITHOUT_ARGS.has(arg) ||
|
||||
[...UNSAFE_OPTIONS_FOR_RIPGREP_WITH_ARGS].some(
|
||||
(opt) => arg === opt || arg.startsWith(`${opt}=`),
|
||||
),
|
||||
);
|
||||
|
||||
if (isUnsafe) {
|
||||
break;
|
||||
}
|
||||
|
||||
case "rg":
|
||||
return {
|
||||
reason: "Ripgrep search",
|
||||
group: "Searching",
|
||||
};
|
||||
}
|
||||
case "find": {
|
||||
// Certain options to `find` allow executing arbitrary processes, so we
|
||||
// cannot auto-approve them.
|
||||
@@ -510,22 +495,6 @@ const UNSAFE_OPTIONS_FOR_FIND_COMMAND: ReadonlySet<string> = new Set([
|
||||
"-fprintf",
|
||||
]);
|
||||
|
||||
// Ripgrep options that are considered unsafe because they may execute
|
||||
// arbitrary commands or spawn auxiliary processes.
|
||||
const UNSAFE_OPTIONS_FOR_RIPGREP_WITH_ARGS: ReadonlySet<string> = new Set([
|
||||
// Executes an arbitrary command for each matching file.
|
||||
"--pre",
|
||||
// Allows custom hostname command which could leak environment details.
|
||||
"--hostname-bin",
|
||||
]);
|
||||
|
||||
const UNSAFE_OPTIONS_FOR_RIPGREP_WITHOUT_ARGS: ReadonlySet<string> = new Set([
|
||||
// Enables searching inside archives which triggers external decompression
|
||||
// utilities – reject out of an abundance of caution.
|
||||
"--search-zip",
|
||||
"-z",
|
||||
]);
|
||||
|
||||
// ---------------- Helper utilities for complex shell expressions -----------------
|
||||
|
||||
// A conservative allow-list of bash operators that do not, on their own, cause
|
||||
|
||||
@@ -44,14 +44,6 @@ describe("canAutoApprove()", () => {
|
||||
group: "Navigating",
|
||||
runInSandbox: false,
|
||||
});
|
||||
|
||||
// Ripgrep safe invocation.
|
||||
expect(check(["rg", "TODO"])).toEqual({
|
||||
type: "auto-approve",
|
||||
reason: "Ripgrep search",
|
||||
group: "Searching",
|
||||
runInSandbox: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("simple safe commands within a `bash -lc` call", () => {
|
||||
@@ -75,24 +67,6 @@ describe("canAutoApprove()", () => {
|
||||
});
|
||||
});
|
||||
|
||||
test("ripgrep unsafe flags", () => {
|
||||
// Flags that do not take arguments
|
||||
expect(check(["rg", "--search-zip", "TODO"])).toEqual({ type: "ask-user" });
|
||||
expect(check(["rg", "-z", "TODO"])).toEqual({ type: "ask-user" });
|
||||
|
||||
// Flags that take arguments (provided separately)
|
||||
expect(check(["rg", "--pre", "cat", "TODO"])).toEqual({ type: "ask-user" });
|
||||
expect(check(["rg", "--hostname-bin", "hostname", "TODO"])).toEqual({
|
||||
type: "ask-user",
|
||||
});
|
||||
|
||||
// Flags that take arguments in = form
|
||||
expect(check(["rg", "--pre=cat", "TODO"])).toEqual({ type: "ask-user" });
|
||||
expect(check(["rg", "--hostname-bin=hostname", "TODO"])).toEqual({
|
||||
type: "ask-user",
|
||||
});
|
||||
});
|
||||
|
||||
test("bash -lc commands with unsafe redirects", () => {
|
||||
expect(check(["bash", "-lc", "echo hello > file.txt"])).toEqual({
|
||||
type: "ask-user",
|
||||
|
||||
80
codex-rs/Cargo.lock
generated
80
codex-rs/Cargo.lock
generated
@@ -399,15 +399,6 @@ version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bstr"
|
||||
version = "1.12.0"
|
||||
@@ -627,7 +618,6 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"chrono",
|
||||
"clap",
|
||||
"clap_complete",
|
||||
"codex-chatgpt",
|
||||
@@ -638,13 +628,14 @@ dependencies = [
|
||||
"codex-login",
|
||||
"codex-mcp-server",
|
||||
"codex-tui",
|
||||
"serde",
|
||||
"indoc",
|
||||
"predicates",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -674,7 +665,6 @@ dependencies = [
|
||||
"fs2",
|
||||
"futures",
|
||||
"landlock",
|
||||
"libc",
|
||||
"maplit",
|
||||
"mcp-types",
|
||||
"mime_guess",
|
||||
@@ -686,7 +676,6 @@ dependencies = [
|
||||
"seccompiler",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"strum_macros 0.27.1",
|
||||
"tempfile",
|
||||
"thiserror 2.0.12",
|
||||
@@ -699,7 +688,6 @@ dependencies = [
|
||||
"tree-sitter",
|
||||
"tree-sitter-bash",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
"wildmatch",
|
||||
"wiremock",
|
||||
]
|
||||
@@ -798,7 +786,6 @@ name = "codex-mcp-server"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"codex-core",
|
||||
"codex-linux-sandbox",
|
||||
"mcp-types",
|
||||
@@ -806,15 +793,10 @@ dependencies = [
|
||||
"schemars 0.8.22",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"toml 0.9.1",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -955,15 +937,6 @@ version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.4.2"
|
||||
@@ -1038,16 +1011,6 @@ version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctor"
|
||||
version = "0.1.26"
|
||||
@@ -1198,16 +1161,6 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8"
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "6.0.0"
|
||||
@@ -1697,16 +1650,6 @@ dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getopts"
|
||||
version = "0.2.23"
|
||||
@@ -4006,17 +3949,6 @@ dependencies = [
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
@@ -4924,12 +4856,6 @@ dependencies = [
|
||||
"unicode-width 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
|
||||
|
||||
[[package]]
|
||||
name = "unicase"
|
||||
version = "2.8.1"
|
||||
|
||||
@@ -27,8 +27,6 @@ codex-linux-sandbox = { path = "../linux-sandbox" }
|
||||
codex-mcp-server = { path = "../mcp-server" }
|
||||
codex-tui = { path = "../tui" }
|
||||
serde_json = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock"] }
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
@@ -38,10 +36,11 @@ tokio = { version = "1", features = [
|
||||
] }
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = "0.3.19"
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tempfile = "3"
|
||||
predicates = "3"
|
||||
tempfile = "3"
|
||||
wiremock = "0.6"
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
indoc = "2"
|
||||
|
||||
@@ -1,357 +0,0 @@
|
||||
use std::fs::File;
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::io::Write; // added for write_all / flush
|
||||
|
||||
use anyhow::Context;
|
||||
use codex_common::ApprovalModeCliArg;
|
||||
use codex_tui::Cli as TuiCli;
|
||||
|
||||
/// Attempt to handle a concurrent background run. Returns Ok(true) if a background exec
|
||||
/// process was spawned (in which case the caller should NOT start the TUI), or Ok(false)
|
||||
/// to proceed with normal interactive execution.
|
||||
pub fn maybe_spawn_concurrent(
|
||||
tui_cli: &mut TuiCli,
|
||||
root_raw_overrides: &[String],
|
||||
concurrent: bool,
|
||||
concurrent_automerge: Option<bool>,
|
||||
concurrent_branch_name: &Option<String>,
|
||||
) -> anyhow::Result<bool> {
|
||||
if !concurrent { return Ok(false); }
|
||||
|
||||
// Enforce autonomous execution conditions when running interactive mode.
|
||||
// Validate git repository presence (required for --concurrent) only if we're in interactive path.
|
||||
{
|
||||
let dir_to_check = tui_cli
|
||||
.cwd
|
||||
.clone()
|
||||
.unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
|
||||
let status = Command::new("git")
|
||||
.arg("-C")
|
||||
.arg(&dir_to_check)
|
||||
.arg("rev-parse")
|
||||
.arg("--git-dir")
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status();
|
||||
if status.as_ref().map(|s| !s.success()).unwrap_or(true) {
|
||||
eprintln!(
|
||||
"Error: --concurrent requires a git repository (directory {:?} is not managed by git).",
|
||||
dir_to_check
|
||||
);
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
let ap = tui_cli.approval_policy;
|
||||
let approval_on_failure = matches!(ap, Some(ApprovalModeCliArg::OnFailure));
|
||||
let autonomous = tui_cli.full_auto
|
||||
|| tui_cli.dangerously_bypass_approvals_and_sandbox
|
||||
|| approval_on_failure;
|
||||
if !autonomous {
|
||||
eprintln!(
|
||||
"Error: --concurrent requires autonomous mode. Use one of: --full-auto, --ask-for-approval on-failure, or --dangerously-bypass-approvals-and-sandbox."
|
||||
);
|
||||
std::process::exit(2);
|
||||
}
|
||||
if tui_cli.prompt.is_none() {
|
||||
eprintln!(
|
||||
"Error: --concurrent requires a prompt argument so the agent does not wait for interactive input."
|
||||
);
|
||||
std::process::exit(2);
|
||||
}
|
||||
|
||||
// Build exec args from interactive CLI for autonomous run without TUI (background).
|
||||
let mut exec_args: Vec<String> = Vec::new();
|
||||
if !tui_cli.images.is_empty() {
|
||||
exec_args.push("--image".into());
|
||||
exec_args.push(tui_cli.images.iter().map(|p| p.display().to_string()).collect::<Vec<_>>().join(","));
|
||||
}
|
||||
if let Some(model) = &tui_cli.model { exec_args.push("--model".into()); exec_args.push(model.clone()); }
|
||||
if let Some(profile) = &tui_cli.config_profile { exec_args.push("--profile".into()); exec_args.push(profile.clone()); }
|
||||
if let Some(sandbox) = &tui_cli.sandbox_mode { exec_args.push("--sandbox".into()); exec_args.push(format!("{sandbox:?}").to_lowercase().replace('_', "-")); }
|
||||
if tui_cli.full_auto { exec_args.push("--full-auto".into()); }
|
||||
if tui_cli.dangerously_bypass_approvals_and_sandbox { exec_args.push("--dangerously-bypass-approvals-and-sandbox".into()); }
|
||||
if tui_cli.skip_git_repo_check { exec_args.push("--skip-git-repo-check".into()); }
|
||||
for raw in root_raw_overrides { exec_args.push("-c".into()); exec_args.push(raw.clone()); }
|
||||
|
||||
// Derive a single slug (shared by worktree branch & log filename) from the prompt.
|
||||
let raw_prompt = tui_cli.prompt.as_deref().unwrap_or("");
|
||||
let snippet = raw_prompt.chars().take(32).collect::<String>();
|
||||
let mut slug: String = snippet
|
||||
.chars()
|
||||
.map(|c| if c.is_ascii_alphanumeric() { c.to_ascii_lowercase() } else { '-' })
|
||||
.collect();
|
||||
while slug.contains("--") { slug = slug.replace("--", "-"); }
|
||||
slug = slug.trim_matches('-').to_string();
|
||||
if slug.is_empty() { slug = "prompt".into(); }
|
||||
|
||||
// Determine concurrent defaults from env (no config file), then apply CLI precedence.
|
||||
let env_automerge = parse_env_bool("CONCURRENT_AUTOMERGE");
|
||||
let env_branch_name = std::env::var("CONCURRENT_BRANCH_NAME").ok();
|
||||
let effective_automerge = concurrent_automerge.or(env_automerge).unwrap_or(true);
|
||||
let user_branch_name_opt = concurrent_branch_name.clone().or(env_branch_name);
|
||||
let branch_name_effective = if let Some(bn_raw) = user_branch_name_opt.as_ref() {
|
||||
let bn_trim = bn_raw.trim();
|
||||
if bn_trim.is_empty() { format!("codex/{slug}") } else { bn_trim.to_string() }
|
||||
} else {
|
||||
format!("codex/{slug}")
|
||||
};
|
||||
|
||||
// Unique job id for this concurrent run (used for log file naming instead of slug).
|
||||
let task_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Prepare log file path early so we can write pre-spawn logs (e.g. worktree creation output) into it.
|
||||
let log_dir = match codex_base_dir() {
|
||||
Ok(base) => {
|
||||
let d = base.join("log");
|
||||
let _ = std::fs::create_dir_all(&d);
|
||||
d
|
||||
}
|
||||
Err(_) => PathBuf::from("/tmp"),
|
||||
};
|
||||
let log_path = log_dir.join(format!("codex-logs-{}.log", task_id));
|
||||
|
||||
// If user did NOT specify an explicit cwd, create an isolated git worktree.
|
||||
let mut created_worktree: Option<(PathBuf, String)> = None; // (path, branch)
|
||||
let mut original_branch: Option<String> = None;
|
||||
let mut original_commit: Option<String> = None;
|
||||
let mut pre_spawn_logs = String::new();
|
||||
if tui_cli.cwd.is_none() {
|
||||
original_branch = git_capture(["rev-parse", "--abbrev-ref", "HEAD"]).ok();
|
||||
original_commit = git_capture(["rev-parse", "HEAD"]).ok();
|
||||
match create_concurrent_worktree(&branch_name_effective) {
|
||||
Ok(Some(info)) => {
|
||||
exec_args.push("--cd".into());
|
||||
exec_args.push(info.worktree_path.display().to_string());
|
||||
created_worktree = Some((info.worktree_path, info.branch_name.clone()));
|
||||
// Keep the original git output plus a concise created line (for log file only).
|
||||
pre_spawn_logs.push_str(&info.logs);
|
||||
pre_spawn_logs.push_str(&format!(
|
||||
"Created git worktree at {} (branch {}) for concurrent run\n",
|
||||
created_worktree.as_ref().unwrap().0.display(), info.branch_name
|
||||
));
|
||||
}
|
||||
Ok(None) => {
|
||||
// Silence console noise: do not warn here to keep stdout clean; we still proceed.
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error: failed to create git worktree for --concurrent: {e}");
|
||||
eprintln!("Hint: remove or rename existing branch '{branch_name_effective}', or pass --concurrent-branch-name to choose a unique name.");
|
||||
std::process::exit(3);
|
||||
}
|
||||
}
|
||||
} else if let Some(explicit) = &tui_cli.cwd {
|
||||
exec_args.push("--cd".into());
|
||||
exec_args.push(explicit.display().to_string());
|
||||
}
|
||||
|
||||
// Prompt (safe to unwrap due to earlier validation).
|
||||
if let Some(prompt) = tui_cli.prompt.clone() { exec_args.push(prompt); }
|
||||
|
||||
// Create (or truncate) the log file and write any pre-spawn logs we captured.
|
||||
let file = match File::create(&log_path) {
|
||||
Ok(mut f) => {
|
||||
if !pre_spawn_logs.is_empty() {
|
||||
let _ = f.write_all(pre_spawn_logs.as_bytes());
|
||||
let _ = f.flush();
|
||||
}
|
||||
f
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to create log file {}: {e}. Falling back to interactive mode.", log_path.display());
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
match File::create(&log_path) {
|
||||
Ok(file) => {
|
||||
let file_err = file.try_clone().ok();
|
||||
let mut cmd = Command::new(
|
||||
std::env::current_exe().unwrap_or_else(|_| PathBuf::from("codex"))
|
||||
);
|
||||
cmd.arg("exec");
|
||||
for a in &exec_args { cmd.arg(a); }
|
||||
// Provide metadata for auto merge if we created a worktree.
|
||||
if let Some((wt_path, branch)) = &created_worktree {
|
||||
if effective_automerge { cmd.env("CODEX_CONCURRENT_AUTOMERGE", "1"); }
|
||||
cmd.env("CODEX_CONCURRENT_BRANCH", branch);
|
||||
cmd.env("CODEX_CONCURRENT_WORKTREE", wt_path);
|
||||
if let Some(ob) = &original_branch { cmd.env("CODEX_ORIGINAL_BRANCH", ob); }
|
||||
if let Some(oc) = &original_commit { cmd.env("CODEX_ORIGINAL_COMMIT", oc); }
|
||||
if let Ok(orig_root) = std::env::current_dir() { cmd.env("CODEX_ORIGINAL_ROOT", orig_root); }
|
||||
}
|
||||
// Provide task id so child process can emit token_count updates to tasks.jsonl.
|
||||
cmd.env("CODEX_TASK_ID", &task_id);
|
||||
cmd.stdout(Stdio::from(file));
|
||||
if let Some(f2) = file_err { cmd.stderr(Stdio::from(f2)); }
|
||||
match cmd.spawn() {
|
||||
Ok(child) => {
|
||||
// Human-friendly multi-line output with bold headers.
|
||||
let branch_val = created_worktree.as_ref().map(|(_, b)| b.as_str()).unwrap_or("(none)");
|
||||
let worktree_val = created_worktree
|
||||
.as_ref()
|
||||
.map(|(p, _)| p.display().to_string())
|
||||
.unwrap_or_else(|| "(original cwd)".to_string());
|
||||
// ANSI escape for bold: \x1b[1m ... \x1b[0m
|
||||
println!("\x1b[1mTask ID:\x1b[0m {}", task_id);
|
||||
println!("\x1b[1mPID:\x1b[0m {}", child.id());
|
||||
println!("\x1b[1mBranch:\x1b[0m {}", branch_val);
|
||||
println!("\x1b[1mWorktree:\x1b[0m {}", worktree_val);
|
||||
println!("\x1b[1mState:\x1b[0m started");
|
||||
// Use bold bright magenta (95) for actionable follow-up commands.
|
||||
println!("\nMonitor all tasks: \x1b[1;95mcodex tasks ls\x1b[0m");
|
||||
println!("Watch this task: \x1b[1;95mcodex logs {} -f\x1b[0m", task_id);
|
||||
|
||||
// Record task metadata to CODEX_HOME/tasks.jsonl (JSON Lines file).
|
||||
let record_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
if let Ok(base) = codex_base_dir() {
|
||||
let tasks_path = base.join("tasks.jsonl");
|
||||
let record = serde_json::json!({
|
||||
"task_id": task_id,
|
||||
"pid": child.id(),
|
||||
"worktree": created_worktree.as_ref().map(|(p, _)| p.display().to_string()),
|
||||
"branch": created_worktree.as_ref().map(|(_, b)| b.clone()),
|
||||
"original_branch": original_branch,
|
||||
"original_commit": original_commit,
|
||||
"log_path": log_path.display().to_string(),
|
||||
"prompt": raw_prompt,
|
||||
"model": tui_cli.model.clone(),
|
||||
"start_time": record_time,
|
||||
"automerge": effective_automerge,
|
||||
"explicit_branch_name": user_branch_name_opt,
|
||||
"token_count": serde_json::Value::Null,
|
||||
"state": "started",
|
||||
});
|
||||
if let Ok(mut f) = std::fs::OpenOptions::new().create(true).append(true).open(&tasks_path) {
|
||||
use std::io::Write;
|
||||
if let Err(e) = writeln!(f, "{}", record.to_string()) {
|
||||
eprintln!("Warning: failed writing task record to {}: {e}", tasks_path.display());
|
||||
}
|
||||
} else {
|
||||
eprintln!("Warning: could not open tasks log file at {}", tasks_path.display());
|
||||
}
|
||||
}
|
||||
return Ok(true); // background spawned
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to start background exec: {e}. Falling back to interactive mode.");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"Failed to create log file {}: {e}. Falling back to interactive mode.",
|
||||
log_path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Return the base Codex directory under the user's home (~/.codex), creating it if necessary.
|
||||
fn codex_base_dir() -> anyhow::Result<PathBuf> {
|
||||
if let Ok(val) = std::env::var("CODEX_HOME") {
|
||||
if !val.is_empty() {
|
||||
return Ok(PathBuf::from(val).canonicalize()?);
|
||||
}
|
||||
}
|
||||
let home = std::env::var_os("HOME").ok_or_else(|| anyhow::anyhow!("Could not find home directory"))?;
|
||||
let base = PathBuf::from(home).join(".codex");
|
||||
std::fs::create_dir_all(&base)?;
|
||||
Ok(base)
|
||||
}
|
||||
|
||||
/// Attempt to create a git worktree for an isolated concurrent run capturing git output.
|
||||
struct WorktreeInfo { worktree_path: PathBuf, branch_name: String, logs: String }
|
||||
fn create_concurrent_worktree(branch_name: &str) -> anyhow::Result<Option<WorktreeInfo>> {
|
||||
// Determine repository root.
|
||||
let output = Command::new("git").arg("rev-parse").arg("--show-toplevel").output();
|
||||
let repo_root = match output {
|
||||
Ok(out) if out.status.success() => {
|
||||
let s = String::from_utf8_lossy(&out.stdout).trim().to_string();
|
||||
if s.is_empty() { return Ok(None); }
|
||||
PathBuf::from(s)
|
||||
}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
// Derive repo name from root directory.
|
||||
let repo_name = repo_root
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("repo");
|
||||
|
||||
// Fast-fail if branch already exists.
|
||||
if Command::new("git")
|
||||
.current_dir(&repo_root)
|
||||
.arg("rev-parse")
|
||||
.arg("--verify")
|
||||
.arg(branch_name)
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false) {
|
||||
anyhow::bail!("branch '{branch_name}' already exists");
|
||||
}
|
||||
|
||||
// Construct worktree directory under ~/.codex/worktrees/<repo_name>/.
|
||||
let base_dir = codex_base_dir()?.join("worktrees").join(repo_name);
|
||||
std::fs::create_dir_all(&base_dir)?;
|
||||
let mut worktree_path = base_dir.join(branch_name.replace('/', "-"));
|
||||
|
||||
if worktree_path.exists() {
|
||||
for i in 1..1000 {
|
||||
let candidate = base_dir.join(format!("{}-{}", branch_name.replace('/', "-"), i));
|
||||
if !candidate.exists() { worktree_path = candidate; break; }
|
||||
}
|
||||
}
|
||||
|
||||
// Run git worktree add capturing output (stdout+stderr).
|
||||
let add_out = Command::new("git")
|
||||
.current_dir(&repo_root)
|
||||
.arg("worktree")
|
||||
.arg("add")
|
||||
.arg("-b")
|
||||
.arg(&branch_name)
|
||||
.arg(&worktree_path)
|
||||
.arg("HEAD")
|
||||
.output()?;
|
||||
if !add_out.status.success() {
|
||||
anyhow::bail!("git worktree add failed with status {}", add_out.status);
|
||||
}
|
||||
let mut logs = String::new();
|
||||
if !add_out.stdout.is_empty() { logs.push_str(&String::from_utf8_lossy(&add_out.stdout)); }
|
||||
if !add_out.stderr.is_empty() { logs.push_str(&String::from_utf8_lossy(&add_out.stderr)); }
|
||||
|
||||
Ok(Some(WorktreeInfo { worktree_path, branch_name: branch_name.to_string(), logs }))
|
||||
}
|
||||
|
||||
/// Helper: capture trimmed stdout of a git command.
|
||||
fn git_capture<I, S>(args: I) -> anyhow::Result<String>
|
||||
where
|
||||
I: IntoIterator<Item = S>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
let mut cmd = Command::new("git");
|
||||
for a in args { cmd.arg(a.as_ref()); }
|
||||
let out = cmd.output().context("running git command")?;
|
||||
if !out.status.success() { anyhow::bail!("git command failed"); }
|
||||
Ok(String::from_utf8_lossy(&out.stdout).trim().to_string())
|
||||
}
|
||||
|
||||
/// Interpret the environment variable `name` as a boolean flag.
///
/// Recognizes (ASCII case-insensitively) `1`/`true`/`yes`/`on` as `Some(true)` and
/// `0`/`false`/`no`/`off` as `Some(false)`. Returns `None` when the variable is unset, is
/// not valid Unicode, or holds any other value (surrounding whitespace is not stripped).
fn parse_env_bool(name: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let value = std::env::var(name).ok()?;
    let is_one_of = |words: &[&str]| words.iter().any(|word| value.eq_ignore_ascii_case(word));

    if is_one_of(&TRUTHY) {
        Some(true)
    } else if is_one_of(&FALSY) {
        Some(false)
    } else {
        None
    }
}
|
||||
@@ -1,185 +0,0 @@
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::path::PathBuf;
|
||||
use std::fs;
|
||||
|
||||
// Command-line arguments consumed by `run_inspect`, parsed via clap's derive API.
// NOTE: plain `//` comments are used here on purpose; clap's derive turns `///` doc comments
// into user-visible help text, so the existing `///` lines below are part of the CLI output.
#[derive(Debug, Parser)]
pub struct InspectCli {
    /// Task identifier (full/short task id or exact branch name)
    pub id: String,
    /// Output JSON instead of human table
    #[arg(long)]
    pub json: bool,
}
|
||||
|
||||
/// One line of `tasks.jsonl` as deserialized by `load_task_records`.
///
/// Every field is optional, so a line only has to carry the keys it wants to update;
/// lines without a `task_id` are ignored by the loader.
#[derive(Debug, Deserialize)]
struct RawRecord {
    // Key used to merge multiple lines into a single task.
    task_id: Option<String>,
    // Metadata fields; the loader only copies these from lines that also carry `start_time`.
    pid: Option<u64>,
    worktree: Option<String>,
    branch: Option<String>,
    original_branch: Option<String>,
    original_commit: Option<String>,
    log_path: Option<String>,
    prompt: Option<String>,
    model: Option<String>,
    start_time: Option<u64>,
    // Progress fields that later lines may supply.
    update_time: Option<u64>,
    // Free-form JSON; the loader reads `total_tokens`, `input_tokens`, `output_tokens`
    // and `reasoning_output_tokens` from it when present.
    token_count: Option<serde_json::Value>,
    state: Option<String>,
    // `end_time` takes precedence over `completion_time` when both are present.
    completion_time: Option<u64>,
    end_time: Option<u64>,
    automerge: Option<bool>,
    explicit_branch_name: Option<String>,
}
|
||||
|
||||
/// Aggregated view of a single task, built by `load_task_records` from all `tasks.jsonl`
/// lines sharing the same `task_id`. Serialized as-is for `--json` output.
#[derive(Debug, serde::Serialize, Default, Clone)]
struct TaskFull {
    task_id: String,
    pid: Option<u64>,
    branch: Option<String>,
    worktree: Option<String>,
    original_branch: Option<String>,
    original_commit: Option<String>,
    log_path: Option<String>,
    prompt: Option<String>,
    model: Option<String>,
    start_time: Option<u64>,
    end_time: Option<u64>,
    state: Option<String>,
    // Token counters extracted from the most recent `token_count` object that carried them.
    total_tokens: Option<u64>,
    input_tokens: Option<u64>,
    output_tokens: Option<u64>,
    reasoning_output_tokens: Option<u64>,
    automerge: Option<bool>,
    explicit_branch_name: Option<String>,
    // Taken from the record's `update_time`.
    last_update_time: Option<u64>,
    // `end_time - start_time` (saturating); only set when both timestamps are known.
    duration_secs: Option<u64>,
}
|
||||
|
||||
pub fn run_inspect(cli: InspectCli) -> anyhow::Result<()> {
|
||||
let id = cli.id.to_lowercase();
|
||||
let tasks = load_task_records()?;
|
||||
let matches: Vec<TaskFull> = tasks
|
||||
.into_iter()
|
||||
.filter(|t| t.task_id.starts_with(&id) || t.branch.as_deref().map(|b| b == id).unwrap_or(false))
|
||||
.collect();
|
||||
if matches.is_empty() {
|
||||
eprintln!("No task matches identifier '{}'.", id);
|
||||
return Ok(());
|
||||
}
|
||||
if matches.len() > 1 {
|
||||
eprintln!("Identifier '{}' is ambiguous; matches: {}", id, matches.iter().map(|m| &m.task_id[..8]).collect::<Vec<_>>().join(", "));
|
||||
return Ok(());
|
||||
}
|
||||
let task = &matches[0];
|
||||
if cli.json {
|
||||
println!("{}", serde_json::to_string_pretty(task)?);
|
||||
return Ok(());
|
||||
}
|
||||
print_human(task);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Locate the Codex base directory without creating anything.
///
/// A non-empty `CODEX_HOME` is canonicalized and returned (`None` if that fails); otherwise
/// the result is `$HOME/.codex`, or `None` when `HOME` is unset.
fn base_dir() -> Option<PathBuf> {
    match std::env::var("CODEX_HOME") {
        Ok(custom) if !custom.is_empty() => std::fs::canonicalize(custom).ok(),
        _ => std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".codex")),
    }
}
|
||||
|
||||
fn load_task_records() -> anyhow::Result<Vec<TaskFull>> {
|
||||
let mut map: std::collections::HashMap<String, TaskFull> = std::collections::HashMap::new();
|
||||
let Some(base) = base_dir() else { return Ok(vec![]); };
|
||||
let tasks = base.join("tasks.jsonl");
|
||||
if !tasks.exists() { return Ok(vec![]); }
|
||||
let f = File::open(tasks)?;
|
||||
let reader = BufReader::new(f);
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
if line.trim().is_empty() { continue; }
|
||||
let Ok(val) = serde_json::from_str::<serde_json::Value>(&line) else { continue };
|
||||
let Ok(rec) = serde_json::from_value::<RawRecord>(val) else { continue };
|
||||
let Some(task_id) = rec.task_id.clone() else { continue };
|
||||
let entry = map.entry(task_id.clone()).or_insert_with(|| TaskFull { task_id: task_id.clone(), ..Default::default() });
|
||||
// Initial metadata fields
|
||||
if rec.start_time.is_some() {
|
||||
entry.pid = rec.pid.or(entry.pid);
|
||||
entry.branch = rec.branch.or(entry.branch.clone());
|
||||
entry.worktree = rec.worktree.or(entry.worktree.clone());
|
||||
entry.original_branch = rec.original_branch.or(entry.original_branch.clone());
|
||||
entry.original_commit = rec.original_commit.or(entry.original_commit.clone());
|
||||
entry.log_path = rec.log_path.or(entry.log_path.clone());
|
||||
entry.prompt = rec.prompt.or(entry.prompt.clone());
|
||||
entry.model = rec.model.or(entry.model.clone());
|
||||
entry.start_time = rec.start_time.or(entry.start_time);
|
||||
entry.automerge = rec.automerge.or(entry.automerge);
|
||||
entry.explicit_branch_name = rec.explicit_branch_name.or(entry.explicit_branch_name.clone());
|
||||
}
|
||||
if let Some(state) = rec.state { entry.state = Some(state); }
|
||||
if rec.update_time.is_some() { entry.last_update_time = rec.update_time; }
|
||||
if rec.end_time.is_some() || rec.completion_time.is_some() {
|
||||
entry.end_time = rec.end_time.or(rec.completion_time).or(entry.end_time);
|
||||
}
|
||||
if let Some(tc) = rec.token_count.as_ref() {
|
||||
if let Some(total) = tc.get("total_tokens").and_then(|v| v.as_u64()) { entry.total_tokens = Some(total); }
|
||||
if let Some(inp) = tc.get("input_tokens").and_then(|v| v.as_u64()) { entry.input_tokens = Some(inp); }
|
||||
if let Some(out) = tc.get("output_tokens").and_then(|v| v.as_u64()) { entry.output_tokens = Some(out); }
|
||||
if let Some(rout) = tc.get("reasoning_output_tokens").and_then(|v| v.as_u64()) { entry.reasoning_output_tokens = Some(rout); }
|
||||
}
|
||||
}
|
||||
// Compute duration
|
||||
for t in map.values_mut() {
|
||||
if let (Some(s), Some(e)) = (t.start_time, t.end_time) { t.duration_secs = Some(e.saturating_sub(s)); }
|
||||
}
|
||||
Ok(map.into_values().collect())
|
||||
}
|
||||
|
||||
fn print_human(task: &TaskFull) {
|
||||
println!("Task {}", task.task_id);
|
||||
println!("State: {}", task.state.as_deref().unwrap_or("?"));
|
||||
if let Some(model) = &task.model { println!("Model: {}", model); } else { println!("Model: {}", resolve_default_model()); }
|
||||
if let Some(branch) = &task.branch { println!("Branch: {}", branch); }
|
||||
if let Some(wt) = &task.worktree { println!("Worktree: {}", wt); }
|
||||
if let Some(ob) = &task.original_branch { println!("Original branch: {}", ob); }
|
||||
if let Some(oc) = &task.original_commit { println!("Original commit: {}", oc); }
|
||||
if let Some(start) = task.start_time { println!("Start: {}", format_epoch(start)); }
|
||||
if let Some(end) = task.end_time { println!("End: {}", format_epoch(end)); }
|
||||
if let Some(d) = task.duration_secs { println!("Duration: {}s", d); }
|
||||
if let Some(pid) = task.pid { println!("PID: {}", pid); }
|
||||
if let Some(log) = &task.log_path { println!("Log: {}", log); }
|
||||
if let Some(am) = task.automerge { println!("Automerge: {}", am); }
|
||||
if let Some(exp) = &task.explicit_branch_name { println!("Explicit branch name: {}", exp); }
|
||||
if let Some(total) = task.total_tokens { println!("Total tokens: {}", total); }
|
||||
if task.input_tokens.is_some() || task.output_tokens.is_some() {
|
||||
println!(" Input: {:?} Output: {:?} Reasoning: {:?}", task.input_tokens, task.output_tokens, task.reasoning_output_tokens);
|
||||
}
|
||||
if let Some(p) = &task.prompt { println!("Prompt:\n{}", p); }
|
||||
}
|
||||
|
||||
fn format_epoch(secs: u64) -> String {
|
||||
use chrono::{TimeZone, Utc};
|
||||
if let Some(dt) = Utc.timestamp_opt(secs as i64, 0).single() { dt.to_rfc3339() } else { secs.to_string() }
|
||||
}
|
||||
|
||||
fn resolve_default_model() -> String {
|
||||
if let Some(base) = base_dir() {
|
||||
let candidates = ["config.json", "config.yaml", "config.yml"];
|
||||
for name in candidates {
|
||||
let p = base.join(name);
|
||||
if p.exists() {
|
||||
if let Ok(raw) = fs::read_to_string(&p) {
|
||||
if name.ends_with(".json") {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&raw) {
|
||||
if let Some(m) = v.get("model").and_then(|x| x.as_str()) { if !m.trim().is_empty() { return m.to_string(); } }
|
||||
}
|
||||
} else {
|
||||
for line in raw.lines() { if let Some(rest) = line.trim().strip_prefix("model:") { let val = rest.trim().trim_matches('"'); if !val.is_empty() { return val.to_string(); } } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"codex-mini-latest".to_string()
|
||||
}
|
||||
@@ -1,11 +1,7 @@
|
||||
pub mod concurrent;
|
||||
pub mod debug_sandbox;
|
||||
mod exit_status;
|
||||
pub mod login;
|
||||
pub mod proto;
|
||||
pub mod tasks;
|
||||
pub mod logs;
|
||||
pub mod inspect;
|
||||
|
||||
use clap::Parser;
|
||||
use codex_common::CliConfigOverrides;
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
use clap::Parser;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
// Command-line arguments of the `logs` subcommand, consumed by `run_logs`.
#[derive(Debug, Parser)]
pub struct LogsCli {
    /// Task identifier: full/short task UUID or branch name
    // `run_logs` lower-cases this value, then prefix-matches it against task
    // ids and compares it for equality against branch names.
    pub id: String,
    /// Follow log output (stream new lines)
    // Selects `tail_file` instead of `print_file`.
    #[arg(short = 'f', long = "follow")]
    pub follow: bool,
    /// Show only the last N lines (like tail -n). If omitted, show full file.
    #[arg(short = 'n', long = "lines")]
    pub lines: Option<usize>,
}
|
||||
|
||||
/// The fields of one `tasks.jsonl` line that the logs command reads.
///
/// Every field is optional; `load_tasks_index` skips records that lack
/// either `task_id` or `log_path`.
#[derive(Debug, Deserialize)]
struct RawRecord {
    /// Identifier of the task the record belongs to.
    task_id: Option<String>,
    /// Branch name recorded for the task, if any.
    branch: Option<String>,
    /// Path of the task's log file.
    log_path: Option<String>,
    /// Start time as recorded in the journal; copied verbatim into `TaskMeta`.
    start_time: Option<u64>,
}
|
||||
|
||||
/// Index entry for one task: what `run_logs` needs to find its log file.
#[derive(Debug, Clone)]
struct TaskMeta {
    /// Full task identifier (also the key of the index map).
    task_id: String,
    /// Branch name, usable as an alternative identifier in `run_logs`.
    branch: Option<String>,
    /// Path of the log file to print or follow.
    log_path: String,
    /// Start time copied from the journal record.
    // NOTE(review): no code in this file reads this field after it is stored.
    start_time: Option<u64>,
}
|
||||
|
||||
pub fn run_logs(cli: LogsCli) -> anyhow::Result<()> {
|
||||
let id = cli.id.to_lowercase();
|
||||
let tasks = load_tasks_index()?;
|
||||
if tasks.is_empty() {
|
||||
eprintln!("No tasks found in tasks.jsonl");
|
||||
return Ok(());
|
||||
}
|
||||
let matches: Vec<&TaskMeta> = tasks
|
||||
.values()
|
||||
.filter(|meta| {
|
||||
meta.task_id.starts_with(&id) || meta.branch.as_deref().map(|b| b == id).unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
if matches.is_empty() {
|
||||
eprintln!("No task matches identifier '{}'.", id);
|
||||
return Ok(());
|
||||
}
|
||||
if matches.len() > 1 {
|
||||
eprintln!("Identifier '{}' is ambiguous; matches: {}", id, matches.iter().map(|m| &m.task_id[..8]).collect::<Vec<_>>().join(", "));
|
||||
return Ok(());
|
||||
}
|
||||
let task = matches[0];
|
||||
let path = PathBuf::from(&task.log_path);
|
||||
if !path.exists() {
|
||||
eprintln!("Log file not found at {}", path.display());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if cli.follow {
|
||||
tail_file(&path, cli.lines)?;
|
||||
} else {
|
||||
print_file(&path, cli.lines)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the directory that holds Codex state such as `tasks.jsonl`.
///
/// Uses the canonicalized `CODEX_HOME` when that variable is set to a
/// non-empty value (`None` if canonicalization fails); otherwise falls back
/// to `$HOME/.codex`, or `None` when `HOME` is unset.
fn base_dir() -> Option<PathBuf> {
    let override_dir = std::env::var("CODEX_HOME")
        .ok()
        .filter(|dir| !dir.is_empty());
    if let Some(dir) = override_dir {
        return std::fs::canonicalize(dir).ok();
    }

    let mut root = PathBuf::from(std::env::var_os("HOME")?);
    root.push(".codex");
    Some(root)
}
|
||||
|
||||
fn load_tasks_index() -> anyhow::Result<HashMap<String, TaskMeta>> {
|
||||
let mut map: HashMap<String, TaskMeta> = HashMap::new();
|
||||
let Some(base) = base_dir() else { return Ok(map); };
|
||||
let tasks = base.join("tasks.jsonl");
|
||||
if !tasks.exists() { return Ok(map); }
|
||||
let f = File::open(tasks)?;
|
||||
let reader = BufReader::new(f);
|
||||
for line in reader.lines() {
|
||||
let Ok(line) = line else { continue };
|
||||
if line.trim().is_empty() { continue; }
|
||||
let Ok(val) = serde_json::from_str::<serde_json::Value>(&line) else { continue };
|
||||
let Ok(rec) = serde_json::from_value::<RawRecord>(val) else { continue };
|
||||
let (Some(task_id), Some(log_path)) = (rec.task_id.clone(), rec.log_path.clone()) else { continue };
|
||||
// Insert or update only if not already present (we just need initial metadata)
|
||||
map.entry(task_id.clone()).or_insert(TaskMeta {
|
||||
task_id,
|
||||
branch: rec.branch,
|
||||
log_path,
|
||||
start_time: rec.start_time,
|
||||
});
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
fn print_file(path: &PathBuf, last_lines: Option<usize>) -> anyhow::Result<()> {
|
||||
if let Some(n) = last_lines {
|
||||
let f = File::open(path)?;
|
||||
let reader = BufReader::new(f);
|
||||
let mut buf: std::collections::VecDeque<String> = std::collections::VecDeque::with_capacity(n);
|
||||
for line in reader.lines() {
|
||||
if let Ok(l) = line { if buf.len() == n { buf.pop_front(); } buf.push_back(l); }
|
||||
}
|
||||
for l in buf { println!("{}", l); }
|
||||
return Ok(());
|
||||
}
|
||||
// Full file
|
||||
let mut f = File::open(path)?;
|
||||
let mut contents = String::new();
|
||||
f.read_to_string(&mut contents)?;
|
||||
print!("{}", contents);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tail_file(path: &PathBuf, last_lines: Option<usize>) -> anyhow::Result<()> {
|
||||
use std::io::{self};
|
||||
// Initial output
|
||||
if let Some(n) = last_lines { print_file(path, Some(n))?; } else { print_file(path, None)?; }
|
||||
let mut f = File::open(path)?;
|
||||
let mut pos = f.metadata()?.len();
|
||||
loop {
|
||||
thread::sleep(Duration::from_millis(500));
|
||||
let meta = match f.metadata() { Ok(m) => m, Err(_) => break };
|
||||
let len = meta.len();
|
||||
if len < pos { // truncated
|
||||
pos = 0;
|
||||
}
|
||||
if len > pos {
|
||||
f.seek(SeekFrom::Start(pos))?;
|
||||
let mut buf = String::new();
|
||||
f.read_to_string(&mut buf)?;
|
||||
if !buf.is_empty() { print!("{}", buf); io::Write::flush(&mut std::io::stdout())?; }
|
||||
pos = len;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,7 +4,6 @@ use clap_complete::Shell;
|
||||
use clap_complete::generate;
|
||||
use codex_chatgpt::apply_command::ApplyCommand;
|
||||
use codex_chatgpt::apply_command::run_apply_command;
|
||||
use codex_cli::concurrent::maybe_spawn_concurrent;
|
||||
use codex_cli::LandlockCommand;
|
||||
use codex_cli::SeatbeltCommand;
|
||||
use codex_cli::login::run_login_with_chatgpt;
|
||||
@@ -33,25 +32,6 @@ struct MultitoolCli {
|
||||
#[clap(flatten)]
|
||||
interactive: TuiCli,
|
||||
|
||||
/// Autonomous mode: run the command in the background & concurrently using a git worktree.
|
||||
/// Requires the current directory (or --cd provided path) to be a git repository.
|
||||
#[clap(long)]
|
||||
concurrent: bool,
|
||||
|
||||
/// Control whether the concurrent run auto-merges the worktree branch back into the original branch.
|
||||
/// Defaults to true (may also be set via CONCURRENT_AUTOMERGE env var).
|
||||
#[clap(long = "concurrent-automerge", value_name = "BOOL")]
|
||||
concurrent_automerge: Option<bool>,
|
||||
|
||||
/// Explicit branch name to use for the concurrent worktree instead of the default `codex/<slug>`.
|
||||
/// May also be set via CONCURRENT_BRANCH_NAME env var.
|
||||
#[clap(long = "concurrent-branch-name", value_name = "BRANCH")]
|
||||
concurrent_branch_name: Option<String>,
|
||||
|
||||
/// Best-of-n: run n concurrent worktrees (1-4) and let user pick the best result. Implies --concurrent and disables automerge.
|
||||
#[clap(long = "best-of-n", short = 'n', value_name = "N", default_value_t = 1)]
|
||||
pub best_of_n: u8,
|
||||
|
||||
#[clap(subcommand)]
|
||||
subcommand: Option<Subcommand>,
|
||||
}
|
||||
@@ -81,15 +61,6 @@ enum Subcommand {
|
||||
/// Apply the latest diff produced by Codex agent as a `git apply` to your local working tree.
|
||||
#[clap(visible_alias = "a")]
|
||||
Apply(ApplyCommand),
|
||||
|
||||
/// Manage / inspect concurrent background tasks.
|
||||
Tasks(codex_cli::tasks::TasksCli),
|
||||
|
||||
/// Show or follow logs for a specific task.
|
||||
Logs(codex_cli::logs::LogsCli),
|
||||
|
||||
/// Inspect full metadata for a task.
|
||||
Inspect(codex_cli::inspect::InspectCli),
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
@@ -133,64 +104,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
match cli.subcommand {
|
||||
None => {
|
||||
let mut tui_cli = cli.interactive;
|
||||
let root_raw_overrides = cli.config_overrides.raw_overrides.clone();
|
||||
prepend_config_flags(&mut tui_cli.config_overrides, cli.config_overrides);
|
||||
// Best-of-n logic
|
||||
if cli.best_of_n > 1 {
|
||||
let n = cli.best_of_n.min(4).max(1);
|
||||
let mut spawned_any = false;
|
||||
let base_branch = if let Some(ref name) = cli.concurrent_branch_name {
|
||||
name.trim().to_string()
|
||||
} else {
|
||||
// Derive slug from prompt (copied from maybe_spawn_concurrent)
|
||||
let raw_prompt = tui_cli.prompt.as_deref().unwrap_or("");
|
||||
let snippet = raw_prompt.chars().take(32).collect::<String>();
|
||||
let mut slug: String = snippet
|
||||
.chars()
|
||||
.map(|c| if c.is_ascii_alphanumeric() { c.to_ascii_lowercase() } else { '-' })
|
||||
.collect();
|
||||
while slug.contains("--") { slug = slug.replace("--", "-"); }
|
||||
slug = slug.trim_matches('-').to_string();
|
||||
if slug.is_empty() { slug = "prompt".into(); }
|
||||
format!("codex/{}", slug)
|
||||
};
|
||||
for i in 1..=n {
|
||||
let mut tui_cli_n = tui_cli.clone();
|
||||
// Suffix branch name with -01, -02, etc.
|
||||
let branch_name = format!("{}-{:02}", base_branch, i);
|
||||
let branch_name_opt = Some(branch_name);
|
||||
// Always automerge = false for best-of-n
|
||||
match maybe_spawn_concurrent(
|
||||
&mut tui_cli_n,
|
||||
&root_raw_overrides,
|
||||
true, // force concurrent
|
||||
Some(false),
|
||||
&branch_name_opt,
|
||||
) {
|
||||
Ok(true) => { spawned_any = true; },
|
||||
Ok(false) => {},
|
||||
Err(e) => { eprintln!("Error spawning best-of-n run {}: {e}", i); },
|
||||
}
|
||||
}
|
||||
if !spawned_any {
|
||||
codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?;
|
||||
}
|
||||
// If any spawned, do not run TUI (user will see task IDs)
|
||||
} else {
|
||||
// Attempt concurrent background spawn; if it returns true we skip launching the TUI.
|
||||
if let Ok(spawned) = maybe_spawn_concurrent(
|
||||
&mut tui_cli,
|
||||
&root_raw_overrides,
|
||||
cli.concurrent,
|
||||
cli.concurrent_automerge,
|
||||
&cli.concurrent_branch_name,
|
||||
) {
|
||||
if !spawned { codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?; }
|
||||
} else {
|
||||
// On error fallback to interactive.
|
||||
codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?;
|
||||
}
|
||||
}
|
||||
codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?;
|
||||
}
|
||||
Some(Subcommand::Exec(mut exec_cli)) => {
|
||||
prepend_config_flags(&mut exec_cli.config_overrides, cli.config_overrides);
|
||||
@@ -232,15 +147,6 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
||||
prepend_config_flags(&mut apply_cli.config_overrides, cli.config_overrides);
|
||||
run_apply_command(apply_cli).await?;
|
||||
}
|
||||
Some(Subcommand::Tasks(tasks_cli)) => {
|
||||
codex_cli::tasks::run_tasks(tasks_cli)?;
|
||||
}
|
||||
Some(Subcommand::Logs(logs_cli)) => {
|
||||
codex_cli::logs::run_logs(logs_cli)?;
|
||||
}
|
||||
Some(Subcommand::Inspect(inspect_cli)) => {
|
||||
codex_cli::inspect::run_inspect(inspect_cli)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -35,7 +35,7 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
|
||||
|
||||
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
||||
let ctrl_c = notify_on_sigint();
|
||||
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
let codex = Arc::new(codex);
|
||||
|
||||
// Task that reads JSON lines from stdin and forwards to Submission Queue
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::fs;
|
||||
use chrono::Local;
|
||||
use codex_common::elapsed::format_duration;
|
||||
|
||||
// Top-level arguments of the `tasks` subcommand; `run_tasks` dispatches on
// the nested `cmd`.
#[derive(Debug, Parser)]
pub struct TasksCli {
    #[command(subcommand)]
    pub cmd: TasksCommand,
}
|
||||
|
||||
// Actions available under `tasks`. `Ls` is the only one and is handled by
// `list_tasks`.
#[derive(Debug, Subcommand)]
pub enum TasksCommand {
    /// List background concurrent tasks (from ~/.codex/tasks.jsonl)
    Ls(TasksListArgs),
}
|
||||
|
||||
// Options of `tasks ls`, interpreted by `list_tasks`.
#[derive(Debug, Parser)]
pub struct TasksListArgs {
    /// Output raw JSON instead of table
    #[arg(long)]
    pub json: bool,
    /// Limit number of tasks displayed (most recent first)
    // Applied after sorting by start time and after the state filter.
    #[arg(long)]
    pub limit: Option<usize>,
    /// Show completed tasks as well (by default only running tasks)
    // Without this flag, tasks whose state is exactly "done" are hidden.
    #[arg(short = 'a', long = "all")]
    pub all: bool,
    /// Show all columns including prompt text
    #[arg(long = "all-columns")]
    pub all_columns: bool,
}
|
||||
|
||||
/// One line of `tasks.jsonl` as read by `list_tasks`.
///
/// All fields are optional because different record kinds carry different
/// subsets; `list_tasks` merges every record that shares a `task_id`.
#[derive(Debug, Deserialize)]
struct RawRecord {
    /// Task the record belongs to; records without it are ignored.
    task_id: Option<String>,
    /// Process id recorded for the task.
    pid: Option<u64>,
    /// Worktree recorded for the task.
    worktree: Option<String>,
    /// Branch recorded for the task.
    branch: Option<String>,
    // NOTE(review): `original_branch`, `original_commit` and `log_path` are
    // deserialized but not used by `list_tasks`.
    original_branch: Option<String>,
    original_commit: Option<String>,
    log_path: Option<String>,
    /// Prompt text of the task.
    prompt: Option<String>,
    /// Model name recorded for the task.
    model: Option<String>,
    /// Start time; its presence marks the record as the initial metadata line.
    start_time: Option<u64>,
    /// Time of a progress update; stored as the task's `last_update_time`.
    update_time: Option<u64>,
    /// Token usage; `list_tasks` reads `total_tokens` when this is an object.
    token_count: Option<serde_json::Value>,
    /// Latest reported state (the listing hides tasks whose state is "done").
    state: Option<String>,
    /// End-of-task time; used when `end_time` is absent.
    completion_time: Option<u64>,
    /// End-of-task time; preferred over `completion_time`.
    end_time: Option<u64>,
}
|
||||
|
||||
/// Merged view of all journal records of one task, as listed by `tasks ls`.
///
/// Serialized as-is for the `--json` output.
#[derive(Debug, Serialize, Default, Clone)]
struct TaskAggregate {
    /// Full task identifier; the table shows at most its first 8 characters.
    task_id: String,
    /// Process id from the initial metadata record.
    pid: Option<u64>,
    /// Branch from the initial metadata record.
    branch: Option<String>,
    /// Worktree from the initial metadata record.
    worktree: Option<String>,
    /// Prompt text; shown only with `--all-columns`.
    prompt: Option<String>,
    /// Model name; the table falls back to `resolve_default_model` when blank.
    model: Option<String>,
    /// Start time; used as the sort key (newest first) and for the "ago" column.
    start_time: Option<u64>,
    /// `update_time` of the most recent record that carried one.
    last_update_time: Option<u64>,
    /// Most recent `total_tokens` value seen.
    total_tokens: Option<u64>,
    /// Most recent state seen.
    state: Option<String>,
    /// End time taken from `end_time` or, failing that, `completion_time`.
    end_time: Option<u64>,
}
|
||||
|
||||
pub fn run_tasks(cmd: TasksCli) -> anyhow::Result<()> {
|
||||
match cmd.cmd {
|
||||
TasksCommand::Ls(args) => list_tasks(args),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves the directory that contains `tasks.jsonl` and the config files.
///
/// `CODEX_HOME`, when set to a non-empty value, is canonicalized and used
/// (`None` if canonicalization fails). Otherwise the directory is
/// `$HOME/.codex`, or `None` when `HOME` is unset.
fn base_dir() -> Option<std::path::PathBuf> {
    let codex_home = std::env::var("CODEX_HOME").unwrap_or_default();
    if codex_home.is_empty() {
        std::env::var_os("HOME").map(|home| std::path::Path::new(&home).join(".codex"))
    } else {
        std::fs::canonicalize(codex_home).ok()
    }
}
|
||||
|
||||
fn list_tasks(args: TasksListArgs) -> anyhow::Result<()> {
|
||||
let Some(base) = base_dir() else {
|
||||
println!("No home directory found; cannot locate tasks.jsonl");
|
||||
return Ok(());
|
||||
};
|
||||
let path = base.join("tasks.jsonl");
|
||||
if !path.exists() {
|
||||
println!("No tasks.jsonl found (no concurrent tasks recorded yet)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let f = File::open(&path)?;
|
||||
let reader = BufReader::new(f);
|
||||
|
||||
let mut agg: HashMap<String, TaskAggregate> = HashMap::new();
|
||||
for line_res in reader.lines() {
|
||||
let line = match line_res { Ok(l) => l, Err(_) => continue };
|
||||
if line.trim().is_empty() { continue; }
|
||||
let raw: serde_json::Value = match serde_json::from_str(&line) { Ok(v) => v, Err(_) => continue };
|
||||
let rec: RawRecord = match serde_json::from_value(raw) { Ok(r) => r, Err(_) => continue };
|
||||
let Some(task_id) = rec.task_id.clone() else { continue }; // must have task_id
|
||||
let entry = agg.entry(task_id.clone()).or_insert_with(|| TaskAggregate { task_id: task_id.clone(), ..Default::default() });
|
||||
if rec.start_time.is_some() { // initial metadata line
|
||||
entry.pid = rec.pid.or(entry.pid);
|
||||
entry.branch = rec.branch.or(entry.branch.clone());
|
||||
entry.worktree = rec.worktree.or(entry.worktree.clone());
|
||||
entry.prompt = rec.prompt.or(entry.prompt.clone());
|
||||
entry.model = rec.model.or(entry.model.clone());
|
||||
entry.start_time = rec.start_time.or(entry.start_time);
|
||||
}
|
||||
if let Some(tc_val) = rec.token_count.as_ref() { if tc_val.is_object() { if let Some(total) = tc_val.get("total_tokens").and_then(|v| v.as_u64()) { entry.total_tokens = Some(total); } } }
|
||||
if rec.update_time.is_some() { entry.last_update_time = rec.update_time; }
|
||||
if let Some(state) = rec.state { entry.state = Some(state); }
|
||||
if rec.completion_time.is_some() || rec.end_time.is_some() {
|
||||
entry.end_time = rec.end_time.or(rec.completion_time).or(entry.end_time);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect and sort by start_time desc
|
||||
let mut tasks: Vec<TaskAggregate> = agg.into_values().collect();
|
||||
tasks.sort_by_key(|j| std::cmp::Reverse(j.start_time.unwrap_or(0)));
|
||||
|
||||
if !args.all { tasks.retain(|j| j.state.as_deref() != Some("done")); }
|
||||
if let Some(limit) = args.limit { tasks.truncate(limit); }
|
||||
|
||||
if args.json {
|
||||
println!("{}", serde_json::to_string_pretty(&tasks)?);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if tasks.is_empty() {
|
||||
println!("No tasks found");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Table header
|
||||
if args.all_columns {
|
||||
println!("\x1b[1m{:<8} {:>6} {:<22} {:<12} {:<8} {:>8} {:<12} {}\x1b[0m", "TASK_ID", "PID", "BRANCH", "START", "STATE", "TOKENS", "MODEL", "PROMPT");
|
||||
} else {
|
||||
// Widened branch column to 22 chars for better readability.
|
||||
println!("\x1b[1m{:<8} {:>6} {:<22} {:<12} {:<8} {:>8} {:<12}\x1b[0m", "TASK_ID", "PID", "BRANCH", "START", "STATE", "TOKENS", "MODEL");
|
||||
}
|
||||
for t in tasks {
|
||||
let task_short = if t.task_id.len() > 8 { &t.task_id[..8] } else { &t.task_id };
|
||||
let pid_str = t.pid.map(|p| p.to_string()).unwrap_or_default();
|
||||
let mut branch = t.branch.clone().unwrap_or_default();
|
||||
let branch_limit = if args.all_columns { 22 } else { 22 }; // unified width
|
||||
if branch.len() > branch_limit { branch.truncate(branch_limit); }
|
||||
let start = t.start_time.map(|start_secs| {
|
||||
let now = Local::now().timestamp() as u64;
|
||||
if now > start_secs {
|
||||
let elapsed = std::time::Duration::from_secs(now - start_secs);
|
||||
format!("{} ago", format_duration(elapsed))
|
||||
} else {
|
||||
"just now".to_string()
|
||||
}
|
||||
}).unwrap_or_default();
|
||||
let tokens = t.total_tokens.map(|t| t.to_string()).unwrap_or_default();
|
||||
let state = t.state.clone().unwrap_or_else(|| "?".into());
|
||||
let mut model = t.model.clone().unwrap_or_default();
|
||||
if model.trim().is_empty() { model = resolve_default_model(); }
|
||||
if model.is_empty() { model.push('-'); }
|
||||
if model.len() > 12 { model.truncate(12); }
|
||||
if args.all_columns {
|
||||
let mut prompt = t.prompt.clone().unwrap_or_default().replace('\n', " ");
|
||||
if prompt.len() > 60 { prompt.truncate(60); }
|
||||
println!("{:<8} {:>6} {:<22} {:<12} {:<8} {:>8} {:<12} {}", task_short, pid_str, branch, start, state, tokens, model, prompt);
|
||||
} else {
|
||||
println!("{:<8} {:>6} {:<22} {:<12} {:<8} {:>8} {:<12}", task_short, pid_str, branch, start, state, tokens, model);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_default_model() -> String {
|
||||
// Attempt to read config json/yaml for model, otherwise fallback to hardcoded default.
|
||||
if let Some(base) = base_dir() {
|
||||
let candidates = ["config.json", "config.yaml", "config.yml"];
|
||||
for name in candidates {
|
||||
let p = base.join(name);
|
||||
if p.exists() {
|
||||
if let Ok(raw) = fs::read_to_string(&p) {
|
||||
// Try JSON first.
|
||||
if name.ends_with(".json") {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&raw) {
|
||||
if let Some(m) = v.get("model").and_then(|x| x.as_str()) {
|
||||
if !m.trim().is_empty() { return m.to_string(); }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Very lightweight YAML parse: look for line starting with model:
|
||||
for line in raw.lines() {
|
||||
if let Some(rest) = line.trim().strip_prefix("model:") {
|
||||
let val = rest.trim().trim_matches('"');
|
||||
if !val.is_empty() {
|
||||
return val.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback default agentic model used elsewhere.
|
||||
"codex-mini-latest".to_string()
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
// Minimal integration test for --concurrent background spawning.
|
||||
// Verifies that invoking the top-level CLI with --concurrent records a task entry
|
||||
// in CODEX_HOME/tasks.jsonl and that multiple invocations append distinct task_ids.
|
||||
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::process::Command;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tempfile::TempDir;
|
||||
|
||||
// Skip helper when sandbox network disabled (mirrors existing tests' behavior).
// Returns true whenever the environment variable named by
// `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` is set, regardless of its value.
fn network_disabled() -> bool {
    std::env::var(codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
}
|
||||
|
||||
/// Runs the CLI twice with `--concurrent` against a throwaway git repository
/// and checks that `CODEX_HOME/tasks.jsonl` ends up with at least two records
/// that have distinct task ids, contain both prompts, and are in the
/// `started` state.
#[test]
fn concurrent_creates_task_records() {
    // Bail out early (test counts as passed) when the sandbox marker is set.
    if network_disabled() {
        eprintln!("Skipping concurrent_creates_task_records due to sandbox network-disabled env");
        return;
    }

    // Temp home (CODEX_HOME) and separate temp git repo.
    let home = TempDir::new().expect("temp home");
    let repo = TempDir::new().expect("temp repo");

    // Initialize a minimal git repository (needed for --concurrent worktree logic).
    assert!(Command::new("git").arg("init").current_dir(repo.path()).status().unwrap().success());
    fs::write(repo.path().join("README.md"), "# temp\n").unwrap();
    assert!(Command::new("git").arg("add").arg(".").current_dir(repo.path()).status().unwrap().success());
    // NOTE(review): only a failure to *spawn* git is tolerated here
    // (`unwrap_or(true)`); if git runs and exits non-zero, e.g. because no
    // user.name/user.email is configured, the assertion fails.
    assert!(Command::new("git")
        .args(["commit", "-m", "init"]) // may warn about user/email; allow non-zero if commit already exists
        .current_dir(repo.path())
        .status()
        .map(|s| s.success())
        .unwrap_or(true));

    // SSE fixture so the spawned background exec does not perform a real network call.
    // Each `writeln!` emits the event line followed by a blank separator line.
    let fixture = home.path().join("fixture.sse");
    let mut f = fs::File::create(&fixture).unwrap();
    writeln!(f, "data: {{\"choices\":[{{\"delta\":{{\"content\":\"ok\"}}}}]}}\n").unwrap();
    writeln!(f, "data: {{\"choices\":[{{\"delta\":{{}}}}]}}\n").unwrap();
    writeln!(f, "data: [DONE]\n").unwrap();

    // Helper to run one concurrent invocation with a given prompt.
    // It goes through `cargo run`, so the first call may also build the CLI.
    let run_once = |prompt: &str| {
        let mut cmd = Command::new("cargo");
        cmd.arg("run")
            .arg("-p")
            .arg("codex-cli")
            .arg("--quiet")
            .arg("--")
            .arg("--concurrent")
            .arg("--full-auto")
            .arg("-C")
            .arg(repo.path())
            .arg(prompt);
        cmd.env("CODEX_HOME", home.path())
            .env("OPENAI_API_KEY", "dummy")
            .env("CODEX_RS_SSE_FIXTURE", &fixture)
            .env("OPENAI_BASE_URL", "http://unused.local");
        let output = cmd.output().expect("spawn codex");
        assert!(output.status.success(), "concurrent codex run failed: stderr={}", String::from_utf8_lossy(&output.stderr));
    };

    run_once("Add a cat in ASCII");
    run_once("Add hello world comment");

    // Wait for tasks.jsonl to contain at least two lines with task records.
    // Polls every 100 ms for up to 10 s.
    let tasks_path = home.path().join("tasks.jsonl");
    let deadline = Instant::now() + Duration::from_secs(10);
    let mut lines: Vec<String> = Vec::new();
    while Instant::now() < deadline {
        if tasks_path.exists() {
            let content = fs::read_to_string(&tasks_path).unwrap_or_default();
            lines = content.lines().filter(|l| !l.trim().is_empty()).map(|s| s.to_string()).collect();
            if lines.len() >= 2 { break; }
        }
        std::thread::sleep(Duration::from_millis(100));
    }
    assert!(lines.len() >= 2, "Expected at least 2 task records, got {}", lines.len());

    // Parse JSON and ensure distinct task_ids and prompts present.
    // NOTE(review): every parsed line must have state "started"; if a
    // background task appends a later state record before the file is read,
    // this assertion would fail. Confirm that is intended.
    let mut task_ids = std::collections::HashSet::new();
    let mut saw_cat = false;
    let mut saw_hello = false;
    for line in &lines {
        if let Ok(val) = serde_json::from_str::<serde_json::Value>(line) {
            if let Some(tid) = val.get("task_id").and_then(|v| v.as_str()) { task_ids.insert(tid.to_string()); }
            if let Some(p) = val.get("prompt").and_then(|v| v.as_str()) {
                if p.contains("cat") { saw_cat = true; }
                if p.contains("hello") { saw_hello = true; }
            }
            assert_eq!(val.get("state").and_then(|v| v.as_str()), Some("started"), "task record missing started state");
        }
    }
    assert!(task_ids.len() >= 2, "Expected distinct task_ids, got {:?}", task_ids);
    assert!(saw_cat, "Did not find cat prompt in tasks.jsonl");
    assert!(saw_hello, "Did not find hello prompt in tasks.jsonl");
}
|
||||
223
codex-rs/cli/tests/integration.rs
Normal file
223
codex-rs/cli/tests/integration.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
//! End-to-end integration tests for the `codex` CLI.
|
||||
//!
|
||||
//! These spin up a local [`wiremock`][] server to stand in for the model API
//! endpoint (`POST /v1/responses`) and then run the real compiled `codex`
//! binary against it. The goal is to verify the high-level request/response
//! flow rather than the details of the individual async functions.
|
||||
//!
|
||||
//! [`wiremock`]: https://docs.rs/wiremock
|
||||
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use predicates::prelude::*;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
// ----- tests -----
|
||||
|
||||
/// Sends a single simple prompt and verifies that the streamed response is
/// surfaced to the user. This exercises the most common "ask a question, get a
/// textual answer" flow.
///
/// The mock server answers `POST /v1/responses` with one SSE body and expects
/// exactly one request. The test is skipped when the sandbox
/// network-disabled variable is set.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn full_conversation_turn_integration() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!("Skipping test because network is disabled");
        return;
    }

    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_raw(sse_message("Hello, world."), "text/event-stream"),
        )
        .expect(1)
        .mount(&server)
        .await;

    // Disable retries — the mock server will fail hard if we make an unexpected
    // request, so retries only slow the test down.
    // NOTE(review): this mutates the process environment while the
    // multi-threaded test runtime is already running; the soundness argument
    // for this `unsafe` block is not stated here — confirm no other thread
    // reads the environment concurrently.
    unsafe {
        std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
        std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
    }

    // Separate temp dirs: one used as CODEX_HOME (receives the config written
    // by `write_config`), one used as the working directory of the CLI.
    let codex_home = TempDir::new().unwrap();
    let sandbox = TempDir::new().unwrap();
    write_config(codex_home.path(), &server);

    // Capture the agent's final message in a file so we can assert on it precisely.
    let last_message_file = sandbox.path().join("last_message.txt");

    let mut cmd = assert_cmd::Command::cargo_bin("codex").unwrap();
    cmd.env("CODEX_HOME", codex_home.path())
        .current_dir(sandbox.path())
        .arg("exec")
        .arg("--skip-git-repo-check")
        .arg("--output-last-message")
        .arg(&last_message_file)
        .arg("Hello");

    cmd.assert()
        .success()
        .stdout(predicate::str::contains("Hello, world."));

    // Assert on the captured last message file (more robust than stdout formatting).
    let last = fs::read_to_string(&last_message_file).unwrap();
    let expected = "Hello, world.";
    assert_eq!(last.trim(), expected);
}
|
||||
|
||||
/// Simulates a tool invocation (`shell`) followed by a second assistant message
|
||||
/// once the tool call completes.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn tool_invocation_flow() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!("Skipping test because network is disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// The first request returns a function-call item; the second returns the
|
||||
// final assistant message. Use an atomic counter to serve them in order.
|
||||
struct SeqResponder {
|
||||
count: std::sync::atomic::AtomicUsize,
|
||||
}
|
||||
impl wiremock::Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
|
||||
use std::sync::atomic::Ordering;
|
||||
match self.count.fetch_add(1, Ordering::SeqCst) {
|
||||
0 => ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_function_call(), "text/event-stream"),
|
||||
_ => ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_final_after_call(), "text/event-stream"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(SeqResponder {
|
||||
count: std::sync::atomic::AtomicUsize::new(0),
|
||||
})
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
|
||||
}
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let sandbox = TempDir::new().unwrap();
|
||||
write_config(codex_home.path(), &server);
|
||||
|
||||
// Capture final assistant message after tool invocation.
|
||||
let last_message_file = sandbox.path().join("last_message.txt");
|
||||
|
||||
let mut cmd = assert_cmd::Command::cargo_bin("codex").unwrap();
|
||||
cmd.env("CODEX_HOME", codex_home.path())
|
||||
.current_dir(sandbox.path())
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("--output-last-message")
|
||||
.arg(&last_message_file)
|
||||
.arg("Run shell");
|
||||
|
||||
cmd.assert()
|
||||
.success()
|
||||
.stdout(predicate::str::contains("exec echo hi"))
|
||||
.stdout(predicate::str::contains("hi"));
|
||||
|
||||
// Assert that the final assistant message (second response) was 'done'.
|
||||
let last = fs::read_to_string(&last_message_file).unwrap();
|
||||
let expected = "done";
|
||||
assert_eq!(last.trim(), expected);
|
||||
}
|
||||
|
||||
/// Write a minimal `config.toml` pointing the CLI at the mock server.
|
||||
fn write_config(codex_home: &Path, server: &MockServer) {
|
||||
fs::write(
|
||||
codex_home.join("config.toml"),
|
||||
format!(
|
||||
r#"
|
||||
model_provider = "mock"
|
||||
model = "test-model"
|
||||
|
||||
[model_providers.mock]
|
||||
name = "mock"
|
||||
base_url = "{}/v1"
|
||||
env_key = "PATH"
|
||||
wire_api = "responses"
|
||||
"#,
|
||||
server.uri()
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Small helper to generate an SSE stream with a single assistant message.
///
/// The returned body contains two events: a `response.output_item.done` event
/// whose assistant message carries `text`, followed by a `response.completed`
/// event with id `resp1`.
///
/// `text` is escaped before being embedded in the JSON payload, so quotes,
/// backslashes and control characters (including newlines, which would
/// otherwise split the SSE `data:` line) produce a well-formed event.
fn sse_message(text: &str) -> String {
    // Escapes `raw` for use inside a JSON string literal (quotes not included).
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }

    const TEMPLATE: &str = r#"event: response.output_item.done
data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"TEXT_PLACEHOLDER"}]}}

event: response.completed
data: {"type":"response.completed","response":{"id":"resp1","output":[]}}


"#;

    TEMPLATE.replace("TEXT_PLACEHOLDER", &escape_json(text))
}
|
||||
|
||||
/// Helper to craft an SSE stream that returns a `function_call`.
|
||||
fn sse_function_call() -> String {
|
||||
let call = serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "function_call",
|
||||
"name": "shell",
|
||||
"arguments": "{\"command\":[\"echo\",\"hi\"]}",
|
||||
"call_id": "call1"
|
||||
}
|
||||
});
|
||||
let completed = serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"response": {"id": "resp1", "output": []}
|
||||
});
|
||||
|
||||
format!(
|
||||
"event: response.output_item.done\ndata: {call}\n\n\
|
||||
event: response.completed\ndata: {completed}\n\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
/// SSE stream for the assistant's final message after the tool call returns.
|
||||
fn sse_final_after_call() -> String {
|
||||
let msg = serde_json::json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "done"}]}
|
||||
});
|
||||
let completed = serde_json::json!({
|
||||
"type": "response.completed",
|
||||
"response": {"id": "resp2", "output": []}
|
||||
});
|
||||
|
||||
format!(
|
||||
"event: response.output_item.done\ndata: {msg}\n\n\
|
||||
event: response.completed\ndata: {completed}\n\n\n"
|
||||
)
|
||||
}
|
||||
@@ -64,11 +64,7 @@ impl CliConfigOverrides {
|
||||
// `-c model=o3` without the quotes.
|
||||
let value: Value = match parse_toml_value(value_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
// Strip leading/trailing quotes if present
|
||||
let trimmed = value_str.trim().trim_matches(|c| c == '"' || c == '\'');
|
||||
Value::String(trimmed.to_string())
|
||||
}
|
||||
Err(_) => Value::String(value_str.to_string()),
|
||||
};
|
||||
|
||||
Ok((key.to_string(), value))
|
||||
|
||||
@@ -22,8 +22,7 @@ fn format_elapsed_millis(millis: i64) -> String {
|
||||
if millis < 1000 {
|
||||
format!("{millis}ms")
|
||||
} else if millis < 60_000 {
|
||||
let secs = millis / 1000;
|
||||
format!("{secs}s")
|
||||
format!("{:.2}s", millis as f64 / 1000.0)
|
||||
} else {
|
||||
let minutes = millis / 60_000;
|
||||
let seconds = (millis % 60_000) / 1000;
|
||||
@@ -49,12 +48,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_format_duration_seconds() {
|
||||
// Durations between 1s (inclusive) and 60s (exclusive) should be
|
||||
// printed as whole seconds.
|
||||
// printed with 2-decimal-place seconds.
|
||||
let dur = Duration::from_millis(1_500); // 1.5s
|
||||
assert_eq!(format_duration(dur), "1s");
|
||||
assert_eq!(format_duration(dur), "1.50s");
|
||||
|
||||
// 59.999s rounds to 60.00s
|
||||
let dur2 = Duration::from_millis(59_999);
|
||||
assert_eq!(format_duration(dur2), "59s");
|
||||
assert_eq!(format_duration(dur2), "60.00s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -92,32 +92,6 @@ http_headers = { "X-Example-Header" = "example-value" }
|
||||
env_http_headers = { "X-Example-Features" = "EXAMPLE_FEATURES" }
|
||||
```
|
||||
|
||||
### Per-provider network tuning
|
||||
|
||||
The following optional settings control retry behaviour and streaming idle timeouts **per model provider**. They must be specified inside the corresponding `[model_providers.<id>]` block in `config.toml`. (Older releases accepted top‑level keys; those are now ignored.)
|
||||
|
||||
Example:
|
||||
|
||||
```toml
|
||||
[model_providers.openai]
|
||||
name = "OpenAI"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
env_key = "OPENAI_API_KEY"
|
||||
# network tuning overrides (all optional; falls back to built‑in defaults)
|
||||
request_max_retries = 4 # retry failed HTTP requests
|
||||
stream_max_retries = 10 # retry dropped SSE streams
|
||||
stream_idle_timeout_ms = 300000 # 5m idle timeout
|
||||
```
|
||||
|
||||
#### request_max_retries
|
||||
How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`.
|
||||
|
||||
#### stream_max_retries
|
||||
Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`.
|
||||
|
||||
#### stream_idle_timeout_ms
|
||||
How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes).
|
||||
|
||||
## model_provider
|
||||
|
||||
Identifies which provider to use from the `model_providers` map. Defaults to `"openai"`. You can override the `base_url` for the built-in `openai` provider via the `OPENAI_BASE_URL` environment variable.
|
||||
@@ -470,7 +444,7 @@ Currently, `"vscode"` is the default, though Codex does not verify VS Code is in
|
||||
|
||||
## hide_agent_reasoning
|
||||
|
||||
Codex intermittently emits "reasoning" events that show the model's internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output.
|
||||
Codex intermittently emits "reasoning" events that show the model’s internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output.
|
||||
|
||||
Setting `hide_agent_reasoning` to `true` suppresses these events in **both** the TUI as well as the headless `exec` sub-command:
|
||||
|
||||
|
||||
@@ -22,14 +22,12 @@ env-flags = "0.1.1"
|
||||
eventsource-stream = "0.2.3"
|
||||
fs2 = "0.4.3"
|
||||
futures = "0.3"
|
||||
libc = "0.2.174"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
mime_guess = "2.0"
|
||||
rand = "0.9"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
strum_macros = "0.27.1"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
@@ -67,5 +65,4 @@ predicates = "3"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
walkdir = "2.5.0"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -21,6 +21,8 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
||||
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_chat_completions_api;
|
||||
@@ -119,7 +121,6 @@ pub(crate) async fn stream_chat_completions(
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = provider.request_max_retries();
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
@@ -135,11 +136,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_chat_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
provider.stream_idle_timeout(),
|
||||
));
|
||||
tokio::spawn(process_chat_sse(stream, tx_event));
|
||||
return Ok(ResponseStream { rx_event });
|
||||
}
|
||||
Ok(res) => {
|
||||
@@ -149,7 +146,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(CodexErr::RetryLimit(status));
|
||||
}
|
||||
|
||||
@@ -165,7 +162,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(e.into());
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
@@ -178,15 +175,14 @@ pub(crate) async fn stream_chat_completions(
|
||||
/// Lightweight SSE processor for the Chat Completions streaming format. The
|
||||
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
|
||||
/// of the pipeline can stay agnostic of the underlying wire format.
|
||||
async fn process_chat_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
idle_timeout: Duration,
|
||||
) where
|
||||
async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
|
||||
where
|
||||
S: Stream<Item = Result<Bytes>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
|
||||
// State to accumulate a function call across streaming chunks.
|
||||
// OpenAI may split the `arguments` string over multiple `delta` events
|
||||
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
|
||||
|
||||
@@ -15,7 +15,6 @@ use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::chat_completions::AggregateStreamExt;
|
||||
use crate::chat_completions::stream_chat_completions;
|
||||
@@ -30,6 +29,8 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
||||
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::models::ResponseItem;
|
||||
@@ -43,7 +44,6 @@ pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
client: reqwest::Client,
|
||||
provider: ModelProviderInfo,
|
||||
session_id: Uuid,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
}
|
||||
@@ -54,13 +54,11 @@ impl ModelClient {
|
||||
provider: ModelProviderInfo,
|
||||
effort: ReasoningEffortConfig,
|
||||
summary: ReasoningSummaryConfig,
|
||||
session_id: Uuid,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
client: reqwest::Client::new(),
|
||||
provider,
|
||||
session_id,
|
||||
effort,
|
||||
summary,
|
||||
}
|
||||
@@ -111,7 +109,7 @@ impl ModelClient {
|
||||
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
|
||||
// short circuit for tests
|
||||
warn!(path, "Streaming from fixture");
|
||||
return stream_from_fixture(path, self.provider.clone()).await;
|
||||
return stream_from_fixture(path).await;
|
||||
}
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||
@@ -138,7 +136,6 @@ impl ModelClient {
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = self.provider.request_max_retries();
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
@@ -146,33 +143,17 @@ impl ModelClient {
|
||||
.provider
|
||||
.create_request_builder(&self.client)?
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header("session_id", self.session_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload);
|
||||
|
||||
let res = req_builder.send().await;
|
||||
if let Ok(resp) = &res {
|
||||
trace!(
|
||||
"Response status: {}, request-id: {}",
|
||||
resp.status(),
|
||||
resp.headers()
|
||||
.get("x-request-id")
|
||||
.map(|v| v.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
|
||||
// spawn task to process SSE
|
||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
self.provider.stream_idle_timeout(),
|
||||
));
|
||||
tokio::spawn(process_sse(stream, tx_event));
|
||||
|
||||
return Ok(ResponseStream { rx_event });
|
||||
}
|
||||
@@ -191,7 +172,7 @@ impl ModelClient {
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(CodexErr::RetryLimit(status));
|
||||
}
|
||||
|
||||
@@ -208,7 +189,7 @@ impl ModelClient {
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
|
||||
return Err(e.into());
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
@@ -217,10 +198,6 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_provider(&self) -> ModelProviderInfo {
|
||||
self.provider.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
@@ -272,16 +249,14 @@ struct ResponseCompletedOutputTokensDetails {
|
||||
reasoning_tokens: u64,
|
||||
}
|
||||
|
||||
async fn process_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent>>,
|
||||
idle_timeout: Duration,
|
||||
) where
|
||||
async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
|
||||
where
|
||||
S: Stream<Item = Result<Bytes>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
// If the stream stays completely silent for an extended period treat it as disconnected.
|
||||
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
// The response id returned from the "complete" message.
|
||||
let mut response_completed: Option<ResponseCompleted> = None;
|
||||
|
||||
@@ -342,7 +317,7 @@ async fn process_sse<S>(
|
||||
// duplicated `output` array embedded in the `response.completed`
|
||||
// payload. That produced two concrete issues:
|
||||
// 1. No real‑time streaming – the user only saw output after the
|
||||
// entire turn had finished, which broke the "typing" UX and
|
||||
// entire turn had finished, which broke the “typing” UX and
|
||||
// made long‑running turns look stalled.
|
||||
// 2. Duplicate `function_call_output` items – both the
|
||||
// individual *and* the completed array were forwarded, which
|
||||
@@ -385,19 +360,6 @@ async fn process_sse<S>(
|
||||
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
|
||||
}
|
||||
}
|
||||
"response.failed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
let error = resp_val
|
||||
.get("error")
|
||||
.and_then(|v| v.get("message"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("response.failed event received");
|
||||
|
||||
let _ = tx_event
|
||||
.send(Err(CodexErr::Stream(error.to_string())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
// Final response completed – includes array of output items & id
|
||||
"response.completed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
@@ -428,10 +390,7 @@ async fn process_sse<S>(
|
||||
}
|
||||
|
||||
/// used in tests to stream from a text SSE file
|
||||
async fn stream_from_fixture(
|
||||
path: impl AsRef<Path>,
|
||||
provider: ModelProviderInfo,
|
||||
) -> Result<ResponseStream> {
|
||||
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let f = std::fs::File::open(path.as_ref())?;
|
||||
let lines = std::io::BufReader::new(f).lines();
|
||||
@@ -445,11 +404,7 @@ async fn stream_from_fixture(
|
||||
|
||||
let rdr = std::io::Cursor::new(content);
|
||||
let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
|
||||
tokio::spawn(process_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
provider.stream_idle_timeout(),
|
||||
));
|
||||
tokio::spawn(process_sse(stream, tx_event));
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
|
||||
@@ -469,10 +424,7 @@ mod tests {
|
||||
|
||||
/// Runs the SSE parser on pre-chunked byte slices and returns every event
|
||||
/// (including any final `Err` from a stream-closure check).
|
||||
async fn collect_events(
|
||||
chunks: &[&[u8]],
|
||||
provider: ModelProviderInfo,
|
||||
) -> Vec<Result<ResponseEvent>> {
|
||||
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
|
||||
let mut builder = IoBuilder::new();
|
||||
for chunk in chunks {
|
||||
builder.read(chunk);
|
||||
@@ -481,7 +433,7 @@ mod tests {
|
||||
let reader = builder.build();
|
||||
let stream = ReaderStream::new(reader).map_err(CodexErr::Io);
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(16);
|
||||
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
|
||||
tokio::spawn(process_sse(stream, tx));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
@@ -492,10 +444,7 @@ mod tests {
|
||||
|
||||
/// Builds an in-memory SSE stream from JSON fixtures and returns only the
|
||||
/// successfully parsed events (panics on internal channel errors).
|
||||
async fn run_sse(
|
||||
events: Vec<serde_json::Value>,
|
||||
provider: ModelProviderInfo,
|
||||
) -> Vec<ResponseEvent> {
|
||||
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
|
||||
let mut body = String::new();
|
||||
for e in events {
|
||||
let kind = e
|
||||
@@ -511,7 +460,7 @@ mod tests {
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
|
||||
let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io);
|
||||
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
|
||||
tokio::spawn(process_sse(stream, tx));
|
||||
|
||||
let mut out = Vec::new();
|
||||
while let Some(ev) = rx.recv().await {
|
||||
@@ -556,25 +505,7 @@ mod tests {
|
||||
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
|
||||
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://test.com".to_string(),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
};
|
||||
|
||||
let events = collect_events(
|
||||
&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()],
|
||||
provider,
|
||||
)
|
||||
.await;
|
||||
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
|
||||
@@ -615,21 +546,8 @@ mod tests {
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://test.com".to_string(),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
};
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
|
||||
@@ -717,21 +635,7 @@ mod tests {
|
||||
let mut evs = vec![case.event];
|
||||
evs.push(completed.clone());
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".to_string(),
|
||||
base_url: "https://test.com".to_string(),
|
||||
env_key: Some("TEST_API_KEY".to_string()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(1000),
|
||||
};
|
||||
|
||||
let out = run_sse(evs, provider).await;
|
||||
let out = run_sse(evs).await;
|
||||
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
|
||||
assert!(
|
||||
(case.expect_first)(&out[0]),
|
||||
|
||||
@@ -34,18 +34,11 @@ pub struct Prompt {
|
||||
/// the "fully qualified" tool name (i.e., prefixed with the server name),
|
||||
/// which should be reported to the model in place of Tool::name.
|
||||
pub extra_tools: HashMap<String, mcp_types::Tool>,
|
||||
|
||||
/// Optional override for the built-in BASE_INSTRUCTIONS.
|
||||
pub base_instructions_override: Option<String>,
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
|
||||
let base = self
|
||||
.base_instructions_override
|
||||
.as_deref()
|
||||
.unwrap_or(BASE_INSTRUCTIONS);
|
||||
let mut sections: Vec<&str> = vec![base];
|
||||
let mut sections: Vec<&str> = vec![BASE_INSTRUCTIONS];
|
||||
if let Some(ref user) = self.user_instructions {
|
||||
sections.push(user);
|
||||
}
|
||||
|
||||
@@ -49,7 +49,9 @@ use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::exec_env::create_env;
|
||||
use crate::flags::OPENAI_STREAM_MAX_RETRIES;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name;
|
||||
use crate::mcp_tool_call::handle_mcp_tool_call;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::FunctionCallOutputPayload;
|
||||
@@ -101,37 +103,26 @@ impl Codex {
|
||||
/// Spawn a new [`Codex`] and initialize the session. Returns the instance
|
||||
/// of `Codex` and the ID of the `SessionInitialized` event that was
|
||||
/// submitted to start the session.
|
||||
pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<(Codex, String, Uuid)> {
|
||||
// experimental resume path (undocumented)
|
||||
let resume_path = config.experimental_resume.clone();
|
||||
info!("resume_path: {resume_path:?}");
|
||||
pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<(Codex, String)> {
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
||||
let (tx_event, rx_event) = async_channel::bounded(1600);
|
||||
|
||||
let user_instructions = get_user_instructions(&config).await;
|
||||
|
||||
let instructions = get_user_instructions(&config).await;
|
||||
let configure_session = Op::ConfigureSession {
|
||||
provider: config.model_provider.clone(),
|
||||
model: config.model.clone(),
|
||||
model_reasoning_effort: config.model_reasoning_effort,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
user_instructions,
|
||||
base_instructions: config.base_instructions.clone(),
|
||||
instructions,
|
||||
approval_policy: config.approval_policy,
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
disable_response_storage: config.disable_response_storage,
|
||||
notify: config.notify.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
resume_path: resume_path.clone(),
|
||||
};
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
let session_id = Uuid::new_v4();
|
||||
tokio::spawn(submission_loop(
|
||||
session_id, config, rx_sub, tx_event, ctrl_c,
|
||||
));
|
||||
tokio::spawn(submission_loop(config, rx_sub, tx_event, ctrl_c));
|
||||
let codex = Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
@@ -139,7 +130,7 @@ impl Codex {
|
||||
};
|
||||
let init_id = codex.submit(configure_session).await?;
|
||||
|
||||
Ok((codex, init_id, session_id))
|
||||
Ok((codex, init_id))
|
||||
}
|
||||
|
||||
/// Submit the `op` wrapped in a `Submission` with a unique ID.
|
||||
@@ -185,8 +176,7 @@ pub(crate) struct Session {
|
||||
/// the model as well as sandbox policies are resolved against this path
|
||||
/// instead of `std::env::current_dir()`.
|
||||
cwd: PathBuf,
|
||||
base_instructions: Option<String>,
|
||||
user_instructions: Option<String>,
|
||||
instructions: Option<String>,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
shell_environment_policy: ShellEnvironmentPolicy,
|
||||
@@ -318,30 +308,24 @@ impl Session {
|
||||
/// transcript, if enabled.
|
||||
async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||
debug!("Recording items for conversation: {items:?}");
|
||||
self.record_state_snapshot(items).await;
|
||||
self.record_rollout_items(items).await;
|
||||
|
||||
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
transcript.record_items(items);
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
||||
let snapshot = {
|
||||
let state = self.state.lock().unwrap();
|
||||
crate::rollout::SessionStateSnapshot {
|
||||
previous_response_id: state.previous_response_id.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
/// Append the given items to the session's rollout transcript (if enabled)
|
||||
/// and persist them to disk.
|
||||
async fn record_rollout_items(&self, items: &[ResponseItem]) {
|
||||
// Clone the recorder outside of the mutex so we don't hold the lock
|
||||
// across an await point (MutexGuard is not Send).
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock().unwrap();
|
||||
guard.as_ref().cloned()
|
||||
};
|
||||
|
||||
if let Some(rec) = recorder {
|
||||
if let Err(e) = rec.record_state(snapshot).await {
|
||||
error!("failed to record rollout state: {e:#}");
|
||||
}
|
||||
if let Err(e) = rec.record_items(items).await {
|
||||
error!("failed to record rollout items: {e:#}");
|
||||
}
|
||||
@@ -529,12 +513,14 @@ impl AgentTask {
|
||||
}
|
||||
|
||||
async fn submission_loop(
|
||||
mut session_id: Uuid,
|
||||
config: Arc<Config>,
|
||||
rx_sub: Receiver<Submission>,
|
||||
tx_event: Sender<Event>,
|
||||
ctrl_c: Arc<Notify>,
|
||||
) {
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
let session_id = Uuid::new_v4();
|
||||
|
||||
let mut sess: Option<Arc<Session>> = None;
|
||||
// shorthand - send an event when there is no active session
|
||||
let send_no_session_event = |sub_id: String| async {
|
||||
@@ -580,18 +566,14 @@ async fn submission_loop(
|
||||
model,
|
||||
model_reasoning_effort,
|
||||
model_reasoning_summary,
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
instructions,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
disable_response_storage,
|
||||
notify,
|
||||
cwd,
|
||||
resume_path,
|
||||
} => {
|
||||
info!(
|
||||
"Configuring session: model={model}; provider={provider:?}; resume={resume_path:?}"
|
||||
);
|
||||
info!("Configuring session: model={model}; provider={provider:?}");
|
||||
if !cwd.is_absolute() {
|
||||
let message = format!("cwd is not absolute: {cwd:?}");
|
||||
error!(message);
|
||||
@@ -604,50 +586,12 @@ async fn submission_loop(
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Optionally resume an existing rollout.
|
||||
let mut restored_items: Option<Vec<ResponseItem>> = None;
|
||||
let mut restored_prev_id: Option<String> = None;
|
||||
let rollout_recorder: Option<RolloutRecorder> =
|
||||
if let Some(path) = resume_path.as_ref() {
|
||||
match RolloutRecorder::resume(path).await {
|
||||
Ok((rec, saved)) => {
|
||||
session_id = saved.session_id;
|
||||
restored_prev_id = saved.state.previous_response_id;
|
||||
if !saved.items.is_empty() {
|
||||
restored_items = Some(saved.items);
|
||||
}
|
||||
Some(rec)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to resume rollout from {path:?}: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let rollout_recorder = match rollout_recorder {
|
||||
Some(rec) => Some(rec),
|
||||
None => {
|
||||
match RolloutRecorder::new(&config, session_id, user_instructions.clone())
|
||||
.await
|
||||
{
|
||||
Ok(r) => Some(r),
|
||||
Err(e) => {
|
||||
warn!("failed to initialise rollout recorder: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client = ModelClient::new(
|
||||
config.clone(),
|
||||
provider.clone(),
|
||||
model_reasoning_effort,
|
||||
model_reasoning_summary,
|
||||
session_id,
|
||||
);
|
||||
|
||||
// abort any current running session and clone its state
|
||||
@@ -701,12 +645,26 @@ async fn submission_loop(
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to create a RolloutRecorder *before* moving the
|
||||
// `instructions` value into the Session struct.
|
||||
// TODO: if ConfigureSession is sent twice, we will create an
|
||||
// overlapping rollout file. Consider passing RolloutRecorder
|
||||
// from above.
|
||||
let rollout_recorder =
|
||||
match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
|
||||
Ok(r) => Some(r),
|
||||
Err(e) => {
|
||||
warn!("failed to initialise rollout recorder: {e}");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
sess = Some(Arc::new(Session {
|
||||
client,
|
||||
tx_event: tx_event.clone(),
|
||||
ctrl_c: Arc::clone(&ctrl_c),
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
instructions,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
shell_environment_policy: config.shell_environment_policy.clone(),
|
||||
@@ -719,19 +677,6 @@ async fn submission_loop(
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
}));
|
||||
|
||||
// Patch restored state into the newly created session.
|
||||
if let Some(sess_arc) = &sess {
|
||||
if restored_prev_id.is_some() || restored_items.is_some() {
|
||||
let mut st = sess_arc.state.lock().unwrap();
|
||||
st.previous_response_id = restored_prev_id;
|
||||
if let (Some(hist), Some(items)) =
|
||||
(st.zdr_transcript.as_mut(), restored_items.as_ref())
|
||||
{
|
||||
hist.record_items(items.iter());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gather history metadata for SessionConfiguredEvent.
|
||||
let (history_log_id, history_entry_count) =
|
||||
crate::message_history::history_metadata(&config).await;
|
||||
@@ -800,8 +745,6 @@ async fn submission_loop(
|
||||
}
|
||||
}
|
||||
Op::AddToHistory { text } => {
|
||||
// TODO: What should we do if we got AddToHistory before ConfigureSession?
|
||||
// currently, if ConfigureSession has resume path, this history will be ignored
|
||||
let id = session_id;
|
||||
let config = config.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -977,17 +920,15 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
let (content, success): (String, Option<bool>) = match result {
|
||||
Ok(CallToolResult {
|
||||
content,
|
||||
is_error,
|
||||
structured_content: _,
|
||||
}) => match serde_json::to_string(content) {
|
||||
Ok(content) => (content, *is_error),
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize MCP tool call output: {e}");
|
||||
(e.to_string(), Some(true))
|
||||
Ok(CallToolResult { content, is_error }) => {
|
||||
match serde_json::to_string(content) {
|
||||
Ok(content) => (content, *is_error),
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize MCP tool call output: {e}");
|
||||
(e.to_string(), Some(true))
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => (e.clone(), Some(true)),
|
||||
};
|
||||
items_to_record_in_conversation_history.push(
|
||||
@@ -1074,10 +1015,9 @@ async fn run_turn(
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
prev_id,
|
||||
user_instructions: sess.user_instructions.clone(),
|
||||
user_instructions: sess.instructions.clone(),
|
||||
store,
|
||||
extra_tools,
|
||||
base_instructions_override: sess.base_instructions.clone(),
|
||||
};
|
||||
|
||||
let mut retries = 0;
|
||||
@@ -1087,13 +1027,12 @@ async fn run_turn(
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e) => {
|
||||
// Use the configured provider-specific stream retry budget.
|
||||
let max_retries = sess.client.get_provider().stream_max_retries();
|
||||
if retries < max_retries {
|
||||
if retries < *OPENAI_STREAM_MAX_RETRIES {
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
warn!(
|
||||
"stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...",
|
||||
"stream disconnected - retrying turn ({retries}/{} in {delay:?})...",
|
||||
*OPENAI_STREAM_MAX_RETRIES
|
||||
);
|
||||
|
||||
// Surface retry information to any UI/front‑end so the
|
||||
@@ -1102,7 +1041,8 @@ async fn run_turn(
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
format!(
|
||||
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…"
|
||||
"stream error: {e}; retrying {retries}/{} in {:?}…",
|
||||
*OPENAI_STREAM_MAX_RETRIES, delay
|
||||
),
|
||||
)
|
||||
.await;
|
||||
@@ -1184,28 +1124,7 @@ async fn try_run_turn(
|
||||
let mut stream = sess.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
loop {
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let Some(event) = event else {
|
||||
// Channel closed without yielding a final Completed event or explicit error.
|
||||
// Treat as a disconnected stream so the caller can retry.
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(ev) => ev,
|
||||
Err(e) => {
|
||||
// Propagate the underlying stream error to the caller (run_turn), which
|
||||
// will apply the configured `stream_max_retries` policy.
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(Ok(event)) = stream.next().await {
|
||||
match event {
|
||||
ResponseEvent::Created => {
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
@@ -1246,7 +1165,7 @@ async fn try_run_turn(
|
||||
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.previous_response_id = Some(response_id);
|
||||
return Ok(output);
|
||||
break;
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
let event = Event {
|
||||
@@ -1264,6 +1183,7 @@ async fn try_run_turn(
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
@@ -1366,13 +1286,13 @@ async fn handle_function_call(
|
||||
let params = match parse_container_exec_arguments(arguments, sess, &call_id) {
|
||||
Ok(params) => params,
|
||||
Err(output) => {
|
||||
return *output;
|
||||
return output;
|
||||
}
|
||||
};
|
||||
handle_container_exec_with_params(params, sess, sub_id, call_id).await
|
||||
}
|
||||
_ => {
|
||||
match sess.mcp_connection_manager.parse_tool_name(&name) {
|
||||
match try_parse_fully_qualified_tool_name(&name) {
|
||||
Some((server, tool_name)) => {
|
||||
// TODO(mbolin): Determine appropriate timeout for tool call.
|
||||
let timeout = None;
|
||||
@@ -1409,7 +1329,7 @@ fn parse_container_exec_arguments(
|
||||
arguments: String,
|
||||
sess: &Session,
|
||||
call_id: &str,
|
||||
) -> Result<ExecParams, Box<ResponseInputItem>> {
|
||||
) -> Result<ExecParams, ResponseInputItem> {
|
||||
// parse command
|
||||
match serde_json::from_str::<ShellToolCallParams>(&arguments) {
|
||||
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)),
|
||||
@@ -1422,7 +1342,7 @@ fn parse_container_exec_arguments(
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
Err(Box::new(output))
|
||||
Err(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,16 +6,15 @@ use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::util::notify_on_sigint;
|
||||
use tokio::sync::Notify;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Spawn a new [`Codex`] and initialize the session.
|
||||
///
|
||||
/// Returns the wrapped [`Codex`] **and** the `SessionInitialized` event that
|
||||
/// is received as a response to the initial `ConfigureSession` submission so
|
||||
/// that callers can surface the information to the UI.
|
||||
pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc<Notify>, Uuid)> {
|
||||
pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc<Notify>)> {
|
||||
let ctrl_c = notify_on_sigint();
|
||||
let (codex, init_id, session_id) = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
let (codex, init_id) = Codex::spawn(config, ctrl_c.clone()).await?;
|
||||
|
||||
// The first event must be `SessionInitialized`. Validate and forward it to
|
||||
// the caller so that they can display it in the conversation history.
|
||||
@@ -34,5 +33,5 @@ pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc<Not
|
||||
));
|
||||
}
|
||||
|
||||
Ok((codex, event, ctrl_c, session_id))
|
||||
Ok((codex, event, ctrl_c))
|
||||
}
|
||||
|
||||
@@ -63,10 +63,7 @@ pub struct Config {
|
||||
pub disable_response_storage: bool,
|
||||
|
||||
/// User-provided instructions from instructions.md.
|
||||
pub user_instructions: Option<String>,
|
||||
|
||||
/// Base instructions override.
|
||||
pub base_instructions: Option<String>,
|
||||
pub instructions: Option<String>,
|
||||
|
||||
/// Optional external notifier command. When set, Codex will spawn this
|
||||
/// program after each completed *turn* (i.e. when the agent finishes
|
||||
@@ -140,9 +137,6 @@ pub struct Config {
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: String,
|
||||
|
||||
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
|
||||
pub experimental_resume: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -327,12 +321,6 @@ pub struct ConfigToml {
|
||||
|
||||
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
|
||||
pub chatgpt_base_url: Option<String>,
|
||||
|
||||
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
|
||||
pub experimental_resume: Option<PathBuf>,
|
||||
|
||||
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
|
||||
pub experimental_instructions_file: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -365,7 +353,6 @@ pub struct ConfigOverrides {
|
||||
pub model_provider: Option<String>,
|
||||
pub config_profile: Option<String>,
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub base_instructions: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -376,7 +363,7 @@ impl Config {
|
||||
overrides: ConfigOverrides,
|
||||
codex_home: PathBuf,
|
||||
) -> std::io::Result<Self> {
|
||||
let user_instructions = Self::load_instructions(Some(&codex_home));
|
||||
let instructions = Self::load_instructions(Some(&codex_home));
|
||||
|
||||
// Destructure ConfigOverrides fully to ensure all overrides are applied.
|
||||
let ConfigOverrides {
|
||||
@@ -387,7 +374,6 @@ impl Config {
|
||||
model_provider,
|
||||
config_profile: config_profile_key,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
} = overrides;
|
||||
|
||||
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
|
||||
@@ -462,13 +448,6 @@ impl Config {
|
||||
.as_ref()
|
||||
.map(|info| info.max_output_tokens)
|
||||
});
|
||||
|
||||
let experimental_resume = cfg.experimental_resume;
|
||||
|
||||
let base_instructions = base_instructions.or(Self::get_base_instructions(
|
||||
cfg.experimental_instructions_file.as_ref(),
|
||||
));
|
||||
|
||||
let config = Self {
|
||||
model,
|
||||
model_context_window,
|
||||
@@ -487,8 +466,7 @@ impl Config {
|
||||
.or(cfg.disable_response_storage)
|
||||
.unwrap_or(false),
|
||||
notify: cfg.notify,
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
model_providers,
|
||||
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
|
||||
@@ -516,8 +494,6 @@ impl Config {
|
||||
.chatgpt_base_url
|
||||
.or(cfg.chatgpt_base_url)
|
||||
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
||||
|
||||
experimental_resume,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
@@ -538,15 +514,6 @@ impl Config {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Reads the base-instructions override from the file at `path`, if a path was given.
///
/// Returns `None` when `path` is `None`, when the file cannot be read (the I/O
/// error is discarded), or when the file contains only whitespace. Otherwise
/// returns the file contents with leading and trailing whitespace removed.
fn get_base_instructions(path: Option<&PathBuf>) -> Option<String> {
    // `path` is already a borrowed option, so `?` applies directly; an extra
    // `as_ref()` would only add a second layer of reference.
    let contents = std::fs::read_to_string(path?).ok()?;
    let trimmed = contents.trim();
    // Allocate the owned copy only once we know there is something to keep.
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
|
||||
}
|
||||
|
||||
fn default_model() -> String {
|
||||
@@ -715,9 +682,6 @@ name = "OpenAI using Chat Completions"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
env_key = "OPENAI_API_KEY"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 4 # retry failed HTTP requests
|
||||
stream_max_retries = 10 # retry dropped SSE streams
|
||||
stream_idle_timeout_ms = 300000 # 5m idle timeout
|
||||
|
||||
[profiles.o3]
|
||||
model = "o3"
|
||||
@@ -758,9 +722,6 @@ disable_response_storage = true
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(4),
|
||||
stream_max_retries: Some(10),
|
||||
stream_idle_timeout_ms: Some(300_000),
|
||||
};
|
||||
let model_provider_map = {
|
||||
let mut model_provider_map = built_in_model_providers();
|
||||
@@ -823,7 +784,7 @@ disable_response_storage = true
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
@@ -839,8 +800,6 @@ disable_response_storage = true
|
||||
model_reasoning_summary: ReasoningSummary::Detailed,
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
},
|
||||
o3_profile_config
|
||||
);
|
||||
@@ -871,7 +830,7 @@ disable_response_storage = true
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: false,
|
||||
user_instructions: None,
|
||||
instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
@@ -887,8 +846,6 @@ disable_response_storage = true
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
|
||||
@@ -934,7 +891,7 @@ disable_response_storage = true
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
shell_environment_policy: ShellEnvironmentPolicy::default(),
|
||||
disable_response_storage: true,
|
||||
user_instructions: None,
|
||||
instructions: None,
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
@@ -950,8 +907,6 @@ disable_response_storage = true
|
||||
model_reasoning_summary: ReasoningSummary::default(),
|
||||
model_supports_reasoning_summaries: false,
|
||||
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
|
||||
experimental_resume: None,
|
||||
base_instructions: None,
|
||||
};
|
||||
|
||||
assert_eq!(expected_zdr_profile_config, zdr_profile_config);
|
||||
|
||||
@@ -384,31 +384,6 @@ async fn spawn_child_async(
|
||||
cmd.env(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR, "1");
|
||||
}
|
||||
|
||||
// If this Codex process dies (including being killed via SIGKILL), we want
|
||||
// any child processes that were spawned as part of a `"shell"` tool call
|
||||
// to also be terminated.
|
||||
|
||||
// This relies on prctl(2), so it only works on Linux.
|
||||
#[cfg(target_os = "linux")]
|
||||
unsafe {
|
||||
cmd.pre_exec(|| {
|
||||
// This prctl call effectively requests, "deliver SIGTERM when my
|
||||
// current parent dies."
|
||||
if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
// Though if there was a race condition and this pre_exec() block is
|
||||
// run _after_ the parent (i.e., the Codex process) has already
|
||||
// exited, then the parent is the _init_ process (which will never
|
||||
// die), so we should just terminate the child process now.
|
||||
if libc::getppid() == 1 {
|
||||
libc::raise(libc::SIGTERM);
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
match stdio_policy {
|
||||
StdioPolicy::RedirectForShellTool => {
|
||||
// Do not create a file descriptor for stdin because otherwise some
|
||||
|
||||
@@ -11,6 +11,14 @@ env_flags! {
|
||||
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||
value.parse().map(Duration::from_millis)
|
||||
};
|
||||
pub OPENAI_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
pub OPENAI_STREAM_MAX_RETRIES: u64 = 10;
|
||||
|
||||
// We generally don't want to disconnect; this updates the timeout to be five minutes
|
||||
// which matches the upstream typescript codex impl.
|
||||
pub OPENAI_STREAM_IDLE_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||
value.parse().map(Duration::from_millis)
|
||||
};
|
||||
|
||||
/// Fixture path for offline tests (see client.rs).
|
||||
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
||||
|
||||
@@ -23,9 +23,9 @@ fn is_safe_to_call_with_exec(command: &[String]) -> bool {
|
||||
let cmd0 = command.first().map(String::as_str);
|
||||
|
||||
match cmd0 {
|
||||
Some("cat" | "cd" | "echo" | "grep" | "head" | "ls" | "pwd" | "tail" | "wc" | "which") => {
|
||||
true
|
||||
}
|
||||
Some(
|
||||
"cat" | "cd" | "echo" | "grep" | "head" | "ls" | "pwd" | "rg" | "tail" | "wc" | "which",
|
||||
) => true,
|
||||
|
||||
Some("find") => {
|
||||
// Certain options to `find` can delete files, write to files, or
|
||||
@@ -46,29 +46,6 @@ fn is_safe_to_call_with_exec(command: &[String]) -> bool {
|
||||
.any(|arg| UNSAFE_FIND_OPTIONS.contains(&arg.as_str()))
|
||||
}
|
||||
|
||||
// Ripgrep
|
||||
Some("rg") => {
|
||||
const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &[
|
||||
// Takes an arbitrary command that is executed for each match.
|
||||
"--pre",
|
||||
// Takes a command that can be used to obtain the local hostname.
|
||||
"--hostname-bin",
|
||||
];
|
||||
const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &[
|
||||
// Calls out to other decompression tools, so do not auto-approve
|
||||
// out of an abundance of caution.
|
||||
"--search-zip",
|
||||
"-z",
|
||||
];
|
||||
|
||||
!command.iter().any(|arg| {
|
||||
UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg.as_str())
|
||||
|| UNSAFE_RIPGREP_OPTIONS_WITH_ARGS
|
||||
.iter()
|
||||
.any(|&opt| arg == opt || arg.starts_with(&format!("{opt}=")))
|
||||
})
|
||||
}
|
||||
|
||||
// Git
|
||||
Some("git") => matches!(
|
||||
command.get(1).map(String::as_str),
|
||||
@@ -268,40 +245,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
fn ripgrep_rules() {
    // A plain search that uses none of the risky switches is accepted.
    let plain_search = vec_str(&["rg", "Cargo.toml", "-n"]);
    assert!(is_safe_to_call_with_exec(&plain_search));

    // Switches that take no value: their mere presence must cause rejection.
    let zip_search_invocations = [
        vec_str(&["rg", "--search-zip", "files"]),
        vec_str(&["rg", "-z", "files"]),
    ];
    for args in &zip_search_invocations {
        assert!(
            !is_safe_to_call_with_exec(args),
            "expected {args:?} to be considered unsafe due to zip-search flag",
        );
    }

    // Switches that take a value, written both as `--flag value` and `--flag=value`.
    let external_command_invocations = [
        vec_str(&["rg", "--pre", "pwned", "files"]),
        vec_str(&["rg", "--pre=pwned", "files"]),
        vec_str(&["rg", "--hostname-bin", "pwned", "files"]),
        vec_str(&["rg", "--hostname-bin=pwned", "files"]),
    ];
    for args in &external_command_invocations {
        assert!(
            !is_safe_to_call_with_exec(args),
            "expected {args:?} to be considered unsafe due to external-command flag",
        );
    }
}
|
||||
|
||||
#[test]
|
||||
fn bash_lc_safe_examples() {
|
||||
assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls"])));
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -17,13 +16,8 @@ use codex_mcp_client::McpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::Tool;
|
||||
|
||||
use serde_json::json;
|
||||
use sha1::Digest;
|
||||
use sha1::Sha1;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::config_types::McpServerConfig;
|
||||
|
||||
@@ -32,8 +26,7 @@ use crate::config_types::McpServerConfig;
|
||||
///
|
||||
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
|
||||
/// choose a delimiter from this character set.
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";
|
||||
|
||||
/// Timeout for the `tools/list` request.
|
||||
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
@@ -42,42 +35,16 @@ const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// spawned successfully.
|
||||
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
|
||||
|
||||
fn qualify_tools(tools: Vec<ToolInfo>) -> HashMap<String, ToolInfo> {
|
||||
let mut used_names = HashSet::new();
|
||||
let mut qualified_tools = HashMap::new();
|
||||
for tool in tools {
|
||||
let mut qualified_name = format!(
|
||||
"{}{}{}",
|
||||
tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name
|
||||
);
|
||||
if qualified_name.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(qualified_name.as_bytes());
|
||||
let sha1 = hasher.finalize();
|
||||
let sha1_str = format!("{sha1:x}");
|
||||
|
||||
// Truncate to make room for the hash suffix
|
||||
let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len();
|
||||
|
||||
qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str);
|
||||
}
|
||||
|
||||
if used_names.contains(&qualified_name) {
|
||||
warn!("skipping duplicated tool {}", qualified_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
used_names.insert(qualified_name.clone());
|
||||
qualified_tools.insert(qualified_name, tool);
|
||||
}
|
||||
|
||||
qualified_tools
|
||||
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
|
||||
format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}")
|
||||
}
|
||||
|
||||
struct ToolInfo {
|
||||
server_name: String,
|
||||
tool_name: String,
|
||||
tool: Tool,
|
||||
pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> {
|
||||
let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?;
|
||||
if server.is_empty() || tool.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((server.to_string(), tool.to_string()))
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`McpClient`] instances.
|
||||
@@ -90,7 +57,7 @@ pub(crate) struct McpConnectionManager {
|
||||
clients: HashMap<String, std::sync::Arc<McpClient>>,
|
||||
|
||||
/// Fully qualified tool name -> tool instance.
|
||||
tools: HashMap<String, ToolInfo>,
|
||||
tools: HashMap<String, Tool>,
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
@@ -136,14 +103,10 @@ impl McpConnectionManager {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
|
||||
// indicates this should be an empty object.
|
||||
elicitation: Some(json!({})),
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".into()),
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
@@ -178,9 +141,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
let all_tools = list_all_tools(&clients).await?;
|
||||
|
||||
let tools = qualify_tools(all_tools);
|
||||
let tools = list_all_tools(&clients).await?;
|
||||
|
||||
Ok((Self { clients, tools }, errors))
|
||||
}
|
||||
@@ -188,10 +149,7 @@ impl McpConnectionManager {
|
||||
/// Returns a single map that contains **all** tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
pub fn list_all_tools(&self) -> HashMap<String, Tool> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|(name, tool)| (name.clone(), tool.tool.clone()))
|
||||
.collect()
|
||||
self.tools.clone()
|
||||
}
|
||||
|
||||
/// Invoke the tool indicated by the (server, tool) pair.
|
||||
@@ -213,19 +171,13 @@ impl McpConnectionManager {
|
||||
.await
|
||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
|
||||
}
|
||||
|
||||
pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.tools
|
||||
.get(tool_name)
|
||||
.map(|tool| (tool.server_name.clone(), tool.tool_name.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
async fn list_all_tools(
|
||||
pub async fn list_all_tools(
|
||||
clients: &HashMap<String, std::sync::Arc<McpClient>>,
|
||||
) -> Result<Vec<ToolInfo>> {
|
||||
) -> Result<HashMap<String, Tool>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn one task per server so we can query them concurrently. This
|
||||
@@ -242,19 +194,18 @@ async fn list_all_tools(
|
||||
});
|
||||
}
|
||||
|
||||
let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());
|
||||
let mut aggregated: HashMap<String, Tool> = HashMap::with_capacity(join_set.len());
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
let (server_name, list_result) = join_res?;
|
||||
let list_result = list_result?;
|
||||
|
||||
for tool in list_result.tools {
|
||||
let tool_info = ToolInfo {
|
||||
server_name: server_name.clone(),
|
||||
tool_name: tool.name.clone(),
|
||||
tool,
|
||||
};
|
||||
aggregated.push(tool_info);
|
||||
// TODO(mbolin): escape tool names that contain invalid characters.
|
||||
let fq_name = fully_qualified_tool_name(&server_name, &tool.name);
|
||||
if aggregated.insert(fq_name.clone(), tool).is_some() {
|
||||
panic!("tool name collision for '{fq_name}': suspicious");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,92 +224,3 @@ fn is_valid_mcp_server_name(server_name: &str) -> bool {
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
}
|
||||
|
||||
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use mcp_types::ToolInputSchema;

    /// Builds a [`ToolInfo`] for `tool_name` on `server_name`, with an empty
    /// object input schema and a generated description.
    fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
        let input_schema = ToolInputSchema {
            properties: None,
            required: None,
            r#type: "object".to_string(),
        };
        let tool = Tool {
            annotations: None,
            description: Some(format!("Test tool: {tool_name}")),
            input_schema,
            name: tool_name.to_string(),
            output_schema: None,
            title: None,
        };

        ToolInfo {
            server_name: server_name.to_string(),
            tool_name: tool_name.to_string(),
            tool,
        }
    }

    /// Short, distinct names are qualified as `<server>__<tool>` and all kept.
    #[test]
    fn test_qualify_tools_short_non_duplicated_names() {
        let qualified = qualify_tools(vec![
            create_test_tool("server1", "tool1"),
            create_test_tool("server1", "tool2"),
        ]);

        assert_eq!(qualified.len(), 2);
        for expected in ["server1__tool1", "server1__tool2"] {
            assert!(qualified.contains_key(expected));
        }
    }

    /// Two entries that qualify to the same name collapse to a single entry.
    #[test]
    fn test_qualify_tools_duplicated_names_skipped() {
        let duplicates = vec![
            create_test_tool("server1", "duplicate_tool"),
            create_test_tool("server1", "duplicate_tool"),
        ];

        let qualified = qualify_tools(duplicates);

        // The repeated entry is dropped, so exactly one key survives.
        assert_eq!(qualified.len(), 1);
        assert!(qualified.contains_key("server1__duplicate_tool"));
    }

    /// Over-long qualified names are shortened to 64 characters, ending in a
    /// hash suffix, and stay distinct from one another.
    #[test]
    fn test_qualify_tools_long_names_same_server() {
        let server_name = "my_server";
        let long_tool_names = [
            "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
            "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
        ];
        let tools: Vec<ToolInfo> = long_tool_names
            .iter()
            .map(|name| create_test_tool(server_name, name))
            .collect();

        let qualified = qualify_tools(tools);

        assert_eq!(qualified.len(), 2);

        // Sort so the two keys can be checked in a deterministic order.
        let mut keys: Vec<String> = qualified.keys().cloned().collect();
        keys.sort();

        assert_eq!(keys[0].len(), 64);
        assert_eq!(
            keys[0],
            "my_server__extremely_lena02e507efc5a9de88637e436690364fd4219e4ef"
        );

        assert_eq!(keys[1].len(), 64);
        assert_eq!(
            keys[1],
            "my_server__yet_another_e1c3987bd9c50b826cbe1687966f79f0c602d19ca"
        );
    }
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::env::VarError;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::EnvVarError;
|
||||
use crate::openai_api_key::get_openai_api_key;
|
||||
@@ -17,9 +16,6 @@ use crate::openai_api_key::get_openai_api_key;
|
||||
/// Value for the `OpenAI-Originator` header that is sent with requests to
|
||||
/// OpenAI.
|
||||
const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs";
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 10;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
|
||||
/// Wire protocol that the provider speaks. Most third-party services only
|
||||
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
|
||||
@@ -30,7 +26,7 @@ const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WireApi {
|
||||
/// The experimental "Responses" API exposed by OpenAI at `/v1/responses`.
|
||||
/// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
|
||||
Responses,
|
||||
|
||||
/// Regular Chat Completions compatible with `/v1/chat/completions`.
|
||||
@@ -68,16 +64,6 @@ pub struct ModelProviderInfo {
|
||||
/// value should be used. If the environment variable is not set, or the
|
||||
/// value is empty, the header will not be included in the request.
|
||||
pub env_http_headers: Option<HashMap<String, String>>,
|
||||
|
||||
/// Maximum number of times to retry a failed HTTP request to this provider.
|
||||
pub request_max_retries: Option<u64>,
|
||||
|
||||
/// Number of times to retry reconnecting a dropped streaming response before failing.
|
||||
pub stream_max_retries: Option<u64>,
|
||||
|
||||
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
|
||||
/// the connection as lost.
|
||||
pub stream_idle_timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
@@ -175,25 +161,6 @@ impl ModelProviderInfo {
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Effective maximum number of request retries for this provider.
|
||||
pub fn request_max_retries(&self) -> u64 {
|
||||
self.request_max_retries
|
||||
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
|
||||
}
|
||||
|
||||
/// Effective maximum number of stream reconnection attempts for this provider.
|
||||
pub fn stream_max_retries(&self) -> u64 {
|
||||
self.stream_max_retries
|
||||
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
|
||||
}
|
||||
|
||||
/// Effective idle timeout for streaming responses.
|
||||
pub fn stream_idle_timeout(&self) -> Duration {
|
||||
self.stream_idle_timeout_ms
|
||||
.map(Duration::from_millis)
|
||||
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in default provider list.
|
||||
@@ -238,10 +205,6 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
// Use global defaults for retry/timeout unless overridden in config.toml.
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -271,9 +234,6 @@ base_url = "http://localhost:11434/v1"
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -299,9 +259,6 @@ query_params = { api-version = "2025-04-01-preview" }
|
||||
}),
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
@@ -330,9 +287,6 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
env_http_headers: Some(maplit::hashmap! {
|
||||
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
|
||||
}),
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
|
||||
@@ -27,16 +27,16 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
|
||||
/// string of instructions.
|
||||
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
|
||||
match find_project_doc(config).await {
|
||||
Ok(Some(project_doc)) => match &config.user_instructions {
|
||||
Ok(Some(project_doc)) => match &config.instructions {
|
||||
Some(original_instructions) => Some(format!(
|
||||
"{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}"
|
||||
)),
|
||||
None => Some(project_doc),
|
||||
},
|
||||
Ok(None) => config.user_instructions.clone(),
|
||||
Ok(None) => config.instructions.clone(),
|
||||
Err(e) => {
|
||||
error!("error trying to find project doc: {e:#}");
|
||||
config.user_instructions.clone()
|
||||
config.instructions.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,7 +159,7 @@ mod tests {
|
||||
config.cwd = root.path().to_path_buf();
|
||||
config.project_doc_max_bytes = limit;
|
||||
|
||||
config.user_instructions = instructions.map(ToOwned::to_owned);
|
||||
config.instructions = instructions.map(ToOwned::to_owned);
|
||||
config
|
||||
}
|
||||
|
||||
|
||||
@@ -44,12 +44,8 @@ pub enum Op {
|
||||
model_reasoning_effort: ReasoningEffortConfig,
|
||||
model_reasoning_summary: ReasoningSummaryConfig,
|
||||
|
||||
/// Model instructions that are appended to the base instructions.
|
||||
user_instructions: Option<String>,
|
||||
|
||||
/// Base instructions override.
|
||||
base_instructions: Option<String>,
|
||||
|
||||
/// Model instructions
|
||||
instructions: Option<String>,
|
||||
/// When to escalate for approval for execution
|
||||
approval_policy: AskForApproval,
|
||||
/// How to sandbox commands executed in the system
|
||||
@@ -73,10 +69,6 @@ pub enum Op {
|
||||
/// `ConfigureSession` operation so that the business-logic layer can
|
||||
/// operate deterministically.
|
||||
cwd: std::path::PathBuf,
|
||||
|
||||
/// Path to a rollout file to resume from.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
resume_path: Option<std::path::PathBuf>,
|
||||
},
|
||||
|
||||
/// Abort current task.
|
||||
|
||||
@@ -1,47 +1,33 @@
|
||||
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
|
||||
//! Functionality to persist a Codex conversation *rollout* – a linear list of
|
||||
//! [`ResponseItem`] objects exchanged during a session – to disk so that
|
||||
//! sessions can be replayed or inspected later (mirrors the behaviour of the
|
||||
//! upstream TypeScript implementation).
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
/// Folder inside `~/.codex` that holds saved rollouts.
|
||||
const SESSIONS_SUBDIR: &str = "sessions";
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Default)]
|
||||
pub struct SessionMeta {
|
||||
pub id: Uuid,
|
||||
pub timestamp: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {
|
||||
pub previous_response_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: Uuid,
|
||||
#[derive(Serialize)]
|
||||
struct SessionMeta {
|
||||
id: String,
|
||||
timestamp: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
instructions: Option<String>,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
@@ -55,13 +41,7 @@ pub struct SavedSession {
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
tx: Sender<String>,
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
@@ -79,6 +59,7 @@ impl RolloutRecorder {
|
||||
timestamp,
|
||||
} = create_log_file(config, uuid)?;
|
||||
|
||||
// Build the static session metadata JSON first.
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
@@ -88,29 +69,46 @@ impl RolloutRecorder {
|
||||
|
||||
let meta = SessionMeta {
|
||||
timestamp,
|
||||
id: session_id,
|
||||
id: session_id.to_string(),
|
||||
instructions,
|
||||
};
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller’s thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
let (tx, mut rx) = mpsc::channel::<String>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(meta),
|
||||
));
|
||||
tokio::task::spawn(async move {
|
||||
let mut file = tokio::fs::File::from_std(file);
|
||||
|
||||
Ok(Self { tx })
|
||||
while let Some(line) = rx.recv().await {
|
||||
// Write line + newline, then flush to disk.
|
||||
if let Err(e) = file.write_all(line.as_bytes()).await {
|
||||
tracing::warn!("rollout writer: failed to write line: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.write_all(b"\n").await {
|
||||
tracing::warn!("rollout writer: failed to write newline: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.flush().await {
|
||||
tracing::warn!("rollout writer: failed to flush: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let recorder = Self { tx };
|
||||
// Ensure SessionMeta is the first item in the file.
|
||||
recorder.record_item(&meta).await?;
|
||||
Ok(recorder)
|
||||
}
|
||||
|
||||
/// Append `items` to the rollout file.
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
match item {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
@@ -119,86 +117,27 @@ impl RolloutRecorder {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
|
||||
| ResponseItem::FunctionCallOutput { .. } => {}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
||||
// These should never be serialized.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.record_item(item).await?;
|
||||
}
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
|
||||
// Serialize the item to JSON first so that the writer thread only has
|
||||
// to perform the actual write.
|
||||
let json = serde_json::to_string(item)
|
||||
.map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
|
||||
|
||||
self.tx
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.send(json)
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
}
|
||||
|
||||
pub async fn resume(path: &Path) -> std::io::Result<(Self, SavedSession)> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
||||
let mut items = Vec::new();
|
||||
let mut state = SessionStateSnapshot::default();
|
||||
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
||||
state = s
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let saved = SavedSession {
|
||||
session: session.clone(),
|
||||
items: items.clone(),
|
||||
state: state.clone(),
|
||||
session_id: session.id,
|
||||
};
|
||||
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.read(true)
|
||||
.open(path)?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
tokio::task::spawn(rollout_writer(tokio::fs::File::from_std(file), rx, None));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,16 +153,14 @@ struct LogFileInfo {
|
||||
}
|
||||
|
||||
fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFileInfo> {
|
||||
// Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
// Resolve ~/.codex/sessions and create it if missing.
|
||||
let mut dir = config.codex_home.clone();
|
||||
dir.push(SESSIONS_SUBDIR);
|
||||
dir.push(timestamp.year().to_string());
|
||||
dir.push(format!("{:02}", u8::from(timestamp.month())));
|
||||
dir.push(format!("{:02}", timestamp.day()));
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
let timestamp = OffsetDateTime::now_local()
|
||||
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
|
||||
|
||||
// Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
|
||||
// compatibility with filesystems that do not allow colons in filenames.
|
||||
let format: &[FormatItem] =
|
||||
@@ -246,54 +183,3 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
async fn rollout_writer(
|
||||
mut file: tokio::fs::File,
|
||||
mut rx: mpsc::Receiver<RolloutCmd>,
|
||||
meta: Option<SessionMeta>,
|
||||
) {
|
||||
if let Some(meta) = meta {
|
||||
if let Ok(json) = serde_json::to_string(&meta) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
}
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => {
|
||||
if let Ok(json) = serde_json::to_string(&item) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
}
|
||||
}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
}
|
||||
if let Ok(json) = serde_json::to_string(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
}) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,7 @@
|
||||
|
||||
use assert_cmd::Command as AssertCommand;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
use walkdir::WalkDir;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -121,241 +117,3 @@ async fn responses_api_stream_cli() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
assert!(stdout.contains("fixture hello"));
|
||||
}
|
||||
|
||||
/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_creates_and_checks_session_file() {
|
||||
// Honor sandbox network restrictions for CI parity with the other tests.
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Temp home so we read/write isolated session files.
|
||||
let home = TempDir::new().unwrap();
|
||||
|
||||
// 2. Unique marker we'll look for in the session log.
|
||||
let marker = format!("integration-test-{}", Uuid::new_v4());
|
||||
let prompt = format!("echo {marker}");
|
||||
|
||||
// 3. Use the same offline SSE fixture as responses_api_stream_cli so the test is hermetic.
|
||||
let fixture =
|
||||
std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/cli_responses_fixture.sse");
|
||||
|
||||
// 4. Run the codex CLI through cargo (ensures the right bin is built) and invoke `exec`,
|
||||
// which is what records a session.
|
||||
let mut cmd = AssertCommand::new("cargo");
|
||||
cmd.arg("run")
|
||||
.arg("-p")
|
||||
.arg("codex-cli")
|
||||
.arg("--quiet")
|
||||
.arg("--")
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg(&prompt);
|
||||
cmd.env("CODEX_HOME", home.path())
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("CODEX_RS_SSE_FIXTURE", &fixture)
|
||||
// Required for CLI arg parsing even though fixture short-circuits network usage.
|
||||
.env("OPENAI_BASE_URL", "http://unused.local");
|
||||
|
||||
let output = cmd.output().unwrap();
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"codex-cli exec failed: {}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
// Wait for sessions dir to appear.
|
||||
let sessions_dir = home.path().join("sessions");
|
||||
let dir_deadline = Instant::now() + Duration::from_secs(5);
|
||||
while !sessions_dir.exists() && Instant::now() < dir_deadline {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
assert!(sessions_dir.exists(), "sessions directory never appeared");
|
||||
|
||||
// Find the session file that contains `marker`.
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
let mut matching_path: Option<std::path::PathBuf> = None;
|
||||
while Instant::now() < deadline && matching_path.is_none() {
|
||||
for entry in WalkDir::new(&sessions_dir) {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if !entry.file_type().is_file() {
|
||||
continue;
|
||||
}
|
||||
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||
continue;
|
||||
}
|
||||
let path = entry.path();
|
||||
let Ok(content) = std::fs::read_to_string(path) else {
|
||||
continue;
|
||||
};
|
||||
let mut lines = content.lines();
|
||||
if lines.next().is_none() {
|
||||
continue;
|
||||
}
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let item: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
|
||||
if let Some(c) = item.get("content") {
|
||||
if c.to_string().contains(&marker) {
|
||||
matching_path = Some(path.to_path_buf());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if matching_path.is_none() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
|
||||
let path = match matching_path {
|
||||
Some(p) => p,
|
||||
None => panic!("No session file containing the marker was found"),
|
||||
};
|
||||
|
||||
// Basic sanity checks on location and metadata.
|
||||
let rel = match path.strip_prefix(&sessions_dir) {
|
||||
Ok(r) => r,
|
||||
Err(_) => panic!("session file should live under sessions/"),
|
||||
};
|
||||
let comps: Vec<String> = rel
|
||||
.components()
|
||||
.map(|c| c.as_os_str().to_string_lossy().into_owned())
|
||||
.collect();
|
||||
assert_eq!(
|
||||
comps.len(),
|
||||
4,
|
||||
"Expected sessions/YYYY/MM/DD/<file>, got {rel:?}"
|
||||
);
|
||||
let year = &comps[0];
|
||||
let month = &comps[1];
|
||||
let day = &comps[2];
|
||||
assert!(
|
||||
year.len() == 4 && year.chars().all(|c| c.is_ascii_digit()),
|
||||
"Year dir not 4-digit numeric: {year}"
|
||||
);
|
||||
assert!(
|
||||
month.len() == 2 && month.chars().all(|c| c.is_ascii_digit()),
|
||||
"Month dir not zero-padded 2-digit numeric: {month}"
|
||||
);
|
||||
assert!(
|
||||
day.len() == 2 && day.chars().all(|c| c.is_ascii_digit()),
|
||||
"Day dir not zero-padded 2-digit numeric: {day}"
|
||||
);
|
||||
if let Ok(m) = month.parse::<u8>() {
|
||||
assert!((1..=12).contains(&m), "Month out of range: {m}");
|
||||
}
|
||||
if let Ok(d) = day.parse::<u8>() {
|
||||
assert!((1..=31).contains(&d), "Day out of range: {d}");
|
||||
}
|
||||
|
||||
let content =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| panic!("Failed to read session file"));
|
||||
let mut lines = content.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or("missing session meta line")
|
||||
.unwrap_or_else(|_| panic!("missing session meta line"));
|
||||
let meta: serde_json::Value = serde_json::from_str(meta_line)
|
||||
.unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
|
||||
assert!(meta.get("id").is_some(), "SessionMeta missing id");
|
||||
assert!(
|
||||
meta.get("timestamp").is_some(),
|
||||
"SessionMeta missing timestamp"
|
||||
);
|
||||
|
||||
let mut found_message = false;
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
|
||||
continue;
|
||||
};
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("message") {
|
||||
if let Some(c) = item.get("content") {
|
||||
if c.to_string().contains(&marker) {
|
||||
found_message = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
found_message,
|
||||
"No message found in session file containing the marker"
|
||||
);
|
||||
|
||||
// Second run: resume and append.
|
||||
let orig_len = content.lines().count();
|
||||
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
|
||||
let prompt2 = format!("echo {marker2}");
|
||||
// Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped
|
||||
// or the parse will fail and the raw literal (including quotes) may be preserved all the way down
|
||||
// to Config, which in turn breaks resume because the path is invalid. Normalize to forward slashes
|
||||
// to sidestep the issue.
|
||||
let resume_path_str = path.to_string_lossy().replace('\\', "/");
|
||||
let resume_override = format!("experimental_resume=\"{resume_path_str}\"");
|
||||
let mut cmd2 = AssertCommand::new("cargo");
|
||||
cmd2.arg("run")
|
||||
.arg("-p")
|
||||
.arg("codex-cli")
|
||||
.arg("--quiet")
|
||||
.arg("--")
|
||||
.arg("exec")
|
||||
.arg("--skip-git-repo-check")
|
||||
.arg("-c")
|
||||
.arg(&resume_override)
|
||||
.arg("-C")
|
||||
.arg(env!("CARGO_MANIFEST_DIR"))
|
||||
.arg(&prompt2);
|
||||
cmd2.env("CODEX_HOME", home.path())
|
||||
.env("OPENAI_API_KEY", "dummy")
|
||||
.env("CODEX_RS_SSE_FIXTURE", &fixture)
|
||||
.env("OPENAI_BASE_URL", "http://unused.local");
|
||||
let output2 = cmd2.output().unwrap();
|
||||
assert!(output2.status.success(), "resume codex-cli run failed");
|
||||
|
||||
// The rollout writer runs on a background async task; give it a moment to flush.
|
||||
let mut new_len = orig_len;
|
||||
let deadline = Instant::now() + Duration::from_secs(5);
|
||||
let mut content2 = String::new();
|
||||
while Instant::now() < deadline {
|
||||
if let Ok(c) = std::fs::read_to_string(&path) {
|
||||
let count = c.lines().count();
|
||||
if count > orig_len {
|
||||
content2 = c;
|
||||
new_len = count;
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
if content2.is_empty() {
|
||||
// last attempt
|
||||
content2 = std::fs::read_to_string(&path).unwrap();
|
||||
new_len = content2.lines().count();
|
||||
}
|
||||
assert!(new_len > orig_len, "rollout file did not grow after resume");
|
||||
assert!(content2.contains(&marker), "rollout lost original marker");
|
||||
assert!(
|
||||
content2.contains(&marker2),
|
||||
"rollout missing resumed marker"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
mod test_support;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use test_support::load_sse_fixture_with_id;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_session_id_and_model_headers_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[("originator".to_string(), "codex_cli_rs".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
|
||||
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_)))
|
||||
.await
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let current_session_id = Some(session_id.to_string());
|
||||
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.headers.get("session_id").unwrap();
|
||||
let originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert!(current_session_id.is_some());
|
||||
assert_eq!(
|
||||
request_body.to_str().unwrap(),
|
||||
current_session_id.as_ref().unwrap()
|
||||
);
|
||||
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_base_instructions_override_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
|
||||
config.base_instructions = Some("test instructions".to_string());
|
||||
config.model_provider = model_provider;
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, ..) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
assert!(
|
||||
request_body["instructions"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("test instructions")
|
||||
);
|
||||
}
|
||||
@@ -45,12 +45,23 @@ async fn spawn_codex() -> Result<Codex, CodexErr> {
|
||||
"OPENAI_API_KEY must be set for live tests"
|
||||
);
|
||||
|
||||
// Environment tweaks to keep the tests snappy and inexpensive while still
|
||||
// exercising retry/robustness logic.
|
||||
//
|
||||
// NOTE: Starting with the 2024 edition `std::env::set_var` is `unsafe`
|
||||
// because changing the process environment races with any other threads
|
||||
// that might be performing environment look-ups at the same time.
|
||||
// Restrict the unsafety to this tiny block that happens at the very
|
||||
// beginning of the test, before we spawn any background tasks that could
|
||||
// observe the environment.
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "2");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "2");
|
||||
}
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider.request_max_retries = Some(2);
|
||||
config.model_provider.stream_max_retries = Some(2);
|
||||
let (agent, _init_id, _session_id) =
|
||||
Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
|
||||
let config = load_default_config_for_test(&codex_home);
|
||||
let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
|
||||
|
||||
Ok(agent)
|
||||
}
|
||||
@@ -68,7 +79,7 @@ async fn live_streaming_and_prev_id_reset() {
|
||||
|
||||
let codex = spawn_codex().await.unwrap();
|
||||
|
||||
// ---------- Task 1 ----------
|
||||
// ---------- Task 1 ----------
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
@@ -102,7 +113,7 @@ async fn live_streaming_and_prev_id_reset() {
|
||||
"Agent did not stream any AgentMessage before TaskComplete"
|
||||
);
|
||||
|
||||
// ---------- Task 2 (same session) ----------
|
||||
// ---------- Task 2 (same session) ----------
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
|
||||
@@ -88,8 +88,13 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
||||
// environment variables.
|
||||
// Environment
|
||||
// Update environment – `set_var` is `unsafe` starting with the 2024
|
||||
// edition so we group the calls into a single `unsafe { … }` block.
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
|
||||
}
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
@@ -102,10 +107,6 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// disable retries so we don't get duplicate calls in this test
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
@@ -113,7 +114,7 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
// Task 1 – triggers first request (no previous_response_id)
|
||||
codex
|
||||
|
||||
@@ -32,6 +32,8 @@ fn sse_completed(id: &str) -> String {
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// this test is flaky (has race conditions), so we ignore it for now
|
||||
#[ignore]
|
||||
async fn retries_on_early_close() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
@@ -70,8 +72,19 @@ async fn retries_on_early_close() {
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
||||
// environment variables.
|
||||
// Environment
|
||||
//
|
||||
// As of Rust 2024 `std::env::set_var` has been made `unsafe` because
|
||||
// mutating the process environment is inherently racy when other threads
|
||||
// are running. We therefore have to wrap every call in an explicit
|
||||
// `unsafe` block. These are limited to the test-setup section so the
|
||||
// scope is very small and clearly delineated.
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
|
||||
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
|
||||
}
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
@@ -85,17 +98,13 @@ async fn retries_on_early_close() {
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
// exercise retry path: first attempt yields incomplete stream, so allow 1 retry
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(1),
|
||||
stream_idle_timeout_ms: Some(2000),
|
||||
};
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.model_provider = model_provider;
|
||||
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c).await.unwrap();
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
|
||||
@@ -76,24 +76,3 @@ pub fn load_sse_fixture_with_id(path: impl AsRef<std::path::Path>, id: &str) ->
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn wait_for_event<F>(
|
||||
codex: &codex_core::Codex,
|
||||
mut predicate: F,
|
||||
) -> codex_core::protocol::EventMsg
|
||||
where
|
||||
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
|
||||
{
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("stream ended unexpectedly");
|
||||
if predicate(&ev.msg) {
|
||||
return ev.msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,10 +51,6 @@ pub struct Cli {
|
||||
#[arg(long = "color", value_enum, default_value_t = Color::Auto)]
|
||||
pub color: Color,
|
||||
|
||||
/// Print events to stdout as JSONL.
|
||||
#[arg(long = "json", default_value_t = false)]
|
||||
pub json: bool,
|
||||
|
||||
/// Specifies file where the last message from the agent should be written.
|
||||
#[arg(long = "output-last-message")]
|
||||
pub last_message_file: Option<PathBuf>,
|
||||
|
||||
@@ -1,37 +1,539 @@
|
||||
use codex_common::elapsed::format_elapsed;
|
||||
use codex_common::summarize_sandbox_policy;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::model_supports_reasoning_summaries;
|
||||
use codex_core::protocol::AgentMessageDeltaEvent;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningDeltaEvent;
|
||||
use codex_core::protocol::BackgroundEventEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandBeginEvent;
|
||||
use codex_core::protocol::ExecCommandEndEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::McpToolCallBeginEvent;
|
||||
use codex_core::protocol::McpToolCallEndEvent;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::time::Instant;
|
||||
|
||||
pub(crate) trait EventProcessor {
|
||||
/// Print summary of effective configuration and user prompt.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str);
|
||||
/// This should be configurable. When used in CI, users may not want to impose
|
||||
/// a limit so they can see the full transcript.
|
||||
const MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL: usize = 20;
|
||||
|
||||
/// Handle a single event emitted by the agent.
|
||||
fn process_event(&mut self, event: Event);
|
||||
pub(crate) struct EventProcessor {
|
||||
call_id_to_command: HashMap<String, ExecCommandBegin>,
|
||||
call_id_to_patch: HashMap<String, PatchApplyBegin>,
|
||||
|
||||
/// Tracks in-flight MCP tool calls so we can calculate duration and print
|
||||
/// a concise summary when the corresponding `McpToolCallEnd` event is
|
||||
/// received.
|
||||
call_id_to_tool_call: HashMap<String, McpToolCallBegin>,
|
||||
|
||||
// To ensure that --color=never is respected, ANSI escapes _must_ be added
|
||||
// using .style() with one of these fields. If you need a new style, add a
|
||||
// new field here.
|
||||
bold: Style,
|
||||
italic: Style,
|
||||
dimmed: Style,
|
||||
|
||||
magenta: Style,
|
||||
red: Style,
|
||||
green: Style,
|
||||
cyan: Style,
|
||||
|
||||
/// Whether to include `AgentReasoning` events in the output.
|
||||
show_agent_reasoning: bool,
|
||||
answer_started: bool,
|
||||
reasoning_started: bool,
|
||||
}
|
||||
|
||||
pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> {
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", format!("{:?}", config.approval_policy)),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
config.model_reasoning_summary.to_string(),
|
||||
));
|
||||
impl EventProcessor {
|
||||
pub(crate) fn create_with_ansi(with_ansi: bool, config: &Config) -> Self {
|
||||
let call_id_to_command = HashMap::new();
|
||||
let call_id_to_patch = HashMap::new();
|
||||
let call_id_to_tool_call = HashMap::new();
|
||||
|
||||
if with_ansi {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new().bold(),
|
||||
italic: Style::new().italic(),
|
||||
dimmed: Style::new().dimmed(),
|
||||
magenta: Style::new().magenta(),
|
||||
red: Style::new().red(),
|
||||
green: Style::new().green(),
|
||||
cyan: Style::new().cyan(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new(),
|
||||
italic: Style::new(),
|
||||
dimmed: Style::new(),
|
||||
magenta: Style::new(),
|
||||
red: Style::new(),
|
||||
green: Style::new(),
|
||||
cyan: Style::new(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Bookkeeping stored per call id when an `ExecCommandBegin` event is
/// processed, and taken back out when the matching `ExecCommandEnd` arrives.
struct ExecCommandBegin {
|
||||
/// Argument vector of the command; reused to label the end-of-command line.
command: Vec<String>,
|
||||
/// Captured with `Instant::now()` at begin time; used to report elapsed time.
start_time: Instant,
|
||||
}
|
||||
|
||||
/// Metadata captured when an `McpToolCallBegin` event is received.
|
||||
struct McpToolCallBegin {
|
||||
/// Formatted invocation string, e.g. `server.tool({"city":"sf"})`.
|
||||
invocation: String,
|
||||
/// Timestamp when the call started so we can compute duration later.
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
/// Bookkeeping stored per call id when a `PatchApplyBegin` event is
/// processed, and taken back out when the matching `PatchApplyEnd` arrives.
struct PatchApplyBegin {
|
||||
/// Captured with `Instant::now()` at begin time; used to report elapsed time.
start_time: Instant,
|
||||
/// Copied from the begin event and echoed in the end-of-patch label.
auto_approved: bool,
|
||||
}
|
||||
|
||||
// Timestamped `println!` helper. Prints the current UTC time formatted as
// `[%Y-%m-%dT%H:%M:%S]`, styled with `$self.dimmed` and followed by a space,
// then forwards the remaining arguments to `println!`.
|
||||
#[macro_export]
|
||||
macro_rules! ts_println {
|
||||
// `$self` must be an identifier (in this file, `self`) that has a `dimmed` field.
($self:ident, $($arg:tt)*) => {{
|
||||
let now = chrono::Utc::now();
|
||||
let formatted = now.format("[%Y-%m-%dT%H:%M:%S]");
|
||||
// The timestamp and the message are written by two separate print calls.
print!("{} ", formatted.style($self.dimmed));
|
||||
println!($($arg)*);
|
||||
}};
|
||||
}
|
||||
|
||||
impl EventProcessor {
|
||||
/// Print a concise summary of the effective configuration that will be used
|
||||
/// for the session. This mirrors the information shown in the TUI welcome
|
||||
/// screen.
|
||||
pub(crate) fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
ts_println!(
|
||||
self,
|
||||
"OpenAI Codex v{} (research preview)\n--------",
|
||||
VERSION
|
||||
);
|
||||
|
||||
let mut entries = vec![
|
||||
("workdir", config.cwd.display().to_string()),
|
||||
("model", config.model.clone()),
|
||||
("provider", config.model_provider_id.clone()),
|
||||
("approval", format!("{:?}", config.approval_policy)),
|
||||
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
|
||||
];
|
||||
if config.model_provider.wire_api == WireApi::Responses
|
||||
&& model_supports_reasoning_summaries(config)
|
||||
{
|
||||
entries.push((
|
||||
"reasoning effort",
|
||||
config.model_reasoning_effort.to_string(),
|
||||
));
|
||||
entries.push((
|
||||
"reasoning summaries",
|
||||
config.model_reasoning_summary.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
for (key, value) in entries {
|
||||
println!("{} {}", format!("{key}:").style(self.bold), value);
|
||||
}
|
||||
|
||||
println!("--------");
|
||||
|
||||
// Echo the prompt that will be sent to the agent so it is visible in the
|
||||
// transcript/logs before any events come in. Note the prompt may have been
|
||||
// read from stdin, so it may not be visible in the terminal otherwise.
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"User instructions:".style(self.bold).style(self.cyan),
|
||||
prompt
|
||||
);
|
||||
}
|
||||
|
||||
entries
|
||||
pub(crate) fn process_event(&mut self, event: Event) {
|
||||
let Event { id: _, msg } = event;
|
||||
match msg {
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
let prefix = "ERROR:".style(self.red);
|
||||
ts_println!(self, "{prefix} {message}");
|
||||
}
|
||||
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
|
||||
ts_println!(self, "{}", message.style(self.dimmed));
|
||||
}
|
||||
EventMsg::TaskStarted | EventMsg::TaskComplete(_) => {
|
||||
// Ignore.
|
||||
}
|
||||
EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => {
|
||||
ts_println!(self, "tokens used: {total_tokens}");
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
|
||||
if !self.answer_started {
|
||||
ts_println!(self, "{}\n", "codex".style(self.italic).style(self.magenta));
|
||||
self.answer_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
|
||||
if !self.show_agent_reasoning {
|
||||
return;
|
||||
}
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n",
|
||||
"thinking".style(self.italic).style(self.magenta),
|
||||
);
|
||||
self.reasoning_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
// if answer_started is false, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new answer.
|
||||
if !self.answer_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"codex".style(self.italic).style(self.magenta),
|
||||
message,
|
||||
);
|
||||
} else {
|
||||
println!();
|
||||
self.answer_started = false;
|
||||
}
|
||||
}
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
}) => {
|
||||
self.call_id_to_command.insert(
|
||||
call_id.clone(),
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {} in {}",
|
||||
"exec".style(self.magenta),
|
||||
escape_command(&command).style(self.bold),
|
||||
cwd.to_string_lossy(),
|
||||
);
|
||||
}
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
}) => {
|
||||
let exec_command = self.call_id_to_command.remove(&call_id);
|
||||
let (duration, call) = if let Some(ExecCommandBegin {
|
||||
command,
|
||||
start_time,
|
||||
}) = exec_command
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("{}", escape_command(&command).style(self.bold)),
|
||||
)
|
||||
} else {
|
||||
("".to_string(), format!("exec('{call_id}')"))
|
||||
};
|
||||
|
||||
let output = if exit_code == 0 { stdout } else { stderr };
|
||||
let truncated_output = output
|
||||
.lines()
|
||||
.take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
match exit_code {
|
||||
0 => {
|
||||
let title = format!("{call} succeeded{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.green));
|
||||
}
|
||||
_ => {
|
||||
let title = format!("{call} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.red));
|
||||
}
|
||||
}
|
||||
println!("{}", truncated_output.style(self.dimmed));
|
||||
}
|
||||
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id,
|
||||
server,
|
||||
tool,
|
||||
arguments,
|
||||
}) => {
|
||||
// Build fully-qualified tool name: server.tool
|
||||
let fq_tool_name = format!("{server}.{tool}");
|
||||
|
||||
// Format arguments as compact JSON so they fit on one line.
|
||||
let args_str = arguments
|
||||
.as_ref()
|
||||
.map(|v: &serde_json::Value| {
|
||||
serde_json::to_string(v).unwrap_or_else(|_| v.to_string())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let invocation = if args_str.is_empty() {
|
||||
format!("{fq_tool_name}()")
|
||||
} else {
|
||||
format!("{fq_tool_name}({args_str})")
|
||||
};
|
||||
|
||||
self.call_id_to_tool_call.insert(
|
||||
call_id.clone(),
|
||||
McpToolCallBegin {
|
||||
invocation: invocation.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"tool".style(self.magenta),
|
||||
invocation.style(self.bold),
|
||||
);
|
||||
}
|
||||
EventMsg::McpToolCallEnd(tool_call_end_event) => {
|
||||
let is_success = tool_call_end_event.is_success();
|
||||
let McpToolCallEndEvent { call_id, result } = tool_call_end_event;
|
||||
// Retrieve start time and invocation for duration calculation and labeling.
|
||||
let info = self.call_id_to_tool_call.remove(&call_id);
|
||||
|
||||
let (duration, invocation) = if let Some(McpToolCallBegin {
|
||||
invocation,
|
||||
start_time,
|
||||
..
|
||||
}) = info
|
||||
{
|
||||
(format!(" in {}", format_elapsed(start_time)), invocation)
|
||||
} else {
|
||||
(String::new(), format!("tool('{call_id}')"))
|
||||
};
|
||||
|
||||
let status_str = if is_success { "success" } else { "failed" };
|
||||
let title_style = if is_success { self.green } else { self.red };
|
||||
let title = format!("{invocation} {status_str}{duration}:");
|
||||
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
|
||||
if let Ok(res) = result {
|
||||
let val: serde_json::Value = res.into();
|
||||
let pretty =
|
||||
serde_json::to_string_pretty(&val).unwrap_or_else(|_| val.to_string());
|
||||
|
||||
for line in pretty.lines().take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved,
|
||||
changes,
|
||||
}) => {
|
||||
// Store metadata so we can calculate duration later when we
|
||||
// receive the corresponding PatchApplyEnd event.
|
||||
self.call_id_to_patch.insert(
|
||||
call_id.clone(),
|
||||
PatchApplyBegin {
|
||||
start_time: Instant::now(),
|
||||
auto_approved,
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} auto_approved={}:",
|
||||
"apply_patch".style(self.magenta),
|
||||
auto_approved,
|
||||
);
|
||||
|
||||
// Pretty-print the patch summary with colored diff markers so
|
||||
// it's easy to scan in the terminal output.
|
||||
for (path, change) in changes.iter() {
|
||||
match change {
|
||||
FileChange::Add { content } => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
for line in content.lines() {
|
||||
println!("{}", line.style(self.green));
|
||||
}
|
||||
}
|
||||
FileChange::Delete => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
}
|
||||
FileChange::Update {
|
||||
unified_diff,
|
||||
move_path,
|
||||
} => {
|
||||
let header = if let Some(dest) = move_path {
|
||||
format!(
|
||||
"{} {} -> {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy(),
|
||||
dest.to_string_lossy()
|
||||
)
|
||||
} else {
|
||||
format!("{} {}", format_file_change(change), path.to_string_lossy())
|
||||
};
|
||||
println!("{}", header.style(self.magenta));
|
||||
|
||||
// Colorize diff lines. We keep file header lines
|
||||
// (--- / +++) without extra coloring so they are
|
||||
// still readable.
|
||||
for diff_line in unified_diff.lines() {
|
||||
if diff_line.starts_with('+') && !diff_line.starts_with("+++") {
|
||||
println!("{}", diff_line.style(self.green));
|
||||
} else if diff_line.starts_with('-')
|
||||
&& !diff_line.starts_with("---")
|
||||
{
|
||||
println!("{}", diff_line.style(self.red));
|
||||
} else {
|
||||
println!("{diff_line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
// Compute duration and summary label similar to exec commands.
|
||||
let (duration, label) = if let Some(PatchApplyBegin {
|
||||
start_time,
|
||||
auto_approved,
|
||||
}) = patch_begin
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("apply_patch(auto_approved={auto_approved})"),
|
||||
)
|
||||
} else {
|
||||
(String::new(), format!("apply_patch('{call_id}')"))
|
||||
};
|
||||
|
||||
let (exit_code, output, title_style) = if success {
|
||||
(0, stdout, self.green)
|
||||
} else {
|
||||
(1, stderr, self.red)
|
||||
};
|
||||
|
||||
let title = format!("{label} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
for line in output.lines() {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::AgentReasoning(agent_reasoning_event) => {
|
||||
if self.show_agent_reasoning {
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"codex".style(self.italic).style(self.magenta),
|
||||
agent_reasoning_event.text,
|
||||
);
|
||||
} else {
|
||||
println!();
|
||||
self.reasoning_started = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::SessionConfigured(session_configured_event) => {
|
||||
let SessionConfiguredEvent {
|
||||
session_id,
|
||||
model,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"codex session".style(self.magenta).style(self.bold),
|
||||
session_id.to_string().style(self.dimmed)
|
||||
);
|
||||
|
||||
ts_println!(self, "model: {}", model);
|
||||
println!();
|
||||
}
|
||||
EventMsg::GetHistoryEntryResponse(_) => {
|
||||
// Currently ignored in exec output.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn escape_command(command: &[String]) -> String {
|
||||
try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
|
||||
}
|
||||
|
||||
fn format_file_change(change: &FileChange) -> &'static str {
|
||||
match change {
|
||||
FileChange::Add { .. } => "A",
|
||||
FileChange::Delete => "D",
|
||||
FileChange::Update {
|
||||
move_path: Some(_), ..
|
||||
} => "R",
|
||||
FileChange::Update {
|
||||
move_path: None, ..
|
||||
} => "M",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,588 +0,0 @@
|
||||
use codex_common::elapsed::format_elapsed;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::AgentMessageDeltaEvent;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::AgentReasoningDeltaEvent;
|
||||
use codex_core::protocol::BackgroundEventEvent;
|
||||
use codex_core::protocol::ErrorEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandBeginEvent;
|
||||
use codex_core::protocol::ExecCommandEndEvent;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::McpToolCallBeginEvent;
|
||||
use codex_core::protocol::McpToolCallEndEvent;
|
||||
use codex_core::protocol::PatchApplyBeginEvent;
|
||||
use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use std::time::Instant;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
|
||||
/// Resolves the directory used for task logging.
///
/// * If `CODEX_HOME` is set to a non-empty value, returns its canonicalized
///   path, or `None` when canonicalization fails (for example because the
///   path does not exist). `HOME` is not consulted in that case.
/// * Otherwise returns `$HOME/.codex`, attempting to create it first and
///   ignoring any creation error. Returns `None` when `HOME` is unset.
fn codex_base_dir_for_logging() -> Option<std::path::PathBuf> {
    match std::env::var("CODEX_HOME") {
        Ok(explicit) if !explicit.is_empty() => std::fs::canonicalize(explicit).ok(),
        _ => {
            let home_dir = std::env::var_os("HOME")?;
            let dir = std::path::PathBuf::from(home_dir).join(".codex");
            // Best effort: a failure here is deliberately not reported.
            let _ = std::fs::create_dir_all(&dir);
            Some(dir)
        }
    }
}
|
||||
|
||||
fn append_json_line(path: &Path, value: &serde_json::Value) -> std::io::Result<()> {
|
||||
use std::io::Write as _;
|
||||
let mut f = std::fs::OpenOptions::new().create(true).append(true).open(path)?;
|
||||
writeln!(f, "{}", value.to_string())
|
||||
}
|
||||
|
||||
/// This should be configurable. When used in CI, users may not want to impose
|
||||
/// a limit so they can see the full transcript.
|
||||
const MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL: usize = 20;
|
||||
pub(crate) struct EventProcessorWithHumanOutput {
|
||||
call_id_to_command: HashMap<String, ExecCommandBegin>,
|
||||
call_id_to_patch: HashMap<String, PatchApplyBegin>,
|
||||
|
||||
/// Tracks in-flight MCP tool calls so we can calculate duration and print
|
||||
/// a concise summary when the corresponding `McpToolCallEnd` event is
|
||||
/// received.
|
||||
call_id_to_tool_call: HashMap<String, McpToolCallBegin>,
|
||||
|
||||
// To ensure that --color=never is respected, ANSI escapes _must_ be added
|
||||
// using .style() with one of these fields. If you need a new style, add a
|
||||
// new field here.
|
||||
bold: Style,
|
||||
italic: Style,
|
||||
dimmed: Style,
|
||||
|
||||
magenta: Style,
|
||||
red: Style,
|
||||
green: Style,
|
||||
cyan: Style,
|
||||
|
||||
/// Whether to include `AgentReasoning` events in the output.
|
||||
show_agent_reasoning: bool,
|
||||
answer_started: bool,
|
||||
reasoning_started: bool,
|
||||
last_token_usage: Option<TokenUsage>,
|
||||
}
|
||||
|
||||
impl EventProcessorWithHumanOutput {
|
||||
pub(crate) fn create_with_ansi(with_ansi: bool, config: &Config) -> Self {
|
||||
let call_id_to_command = HashMap::new();
|
||||
let call_id_to_patch = HashMap::new();
|
||||
let call_id_to_tool_call = HashMap::new();
|
||||
|
||||
if with_ansi {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new().bold(),
|
||||
italic: Style::new().italic(),
|
||||
dimmed: Style::new().dimmed(),
|
||||
magenta: Style::new().magenta(),
|
||||
red: Style::new().red(),
|
||||
green: Style::new().green(),
|
||||
cyan: Style::new().cyan(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
last_token_usage: None,
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
call_id_to_command,
|
||||
call_id_to_patch,
|
||||
bold: Style::new(),
|
||||
italic: Style::new(),
|
||||
dimmed: Style::new(),
|
||||
magenta: Style::new(),
|
||||
red: Style::new(),
|
||||
green: Style::new(),
|
||||
cyan: Style::new(),
|
||||
call_id_to_tool_call,
|
||||
show_agent_reasoning: !config.hide_agent_reasoning,
|
||||
answer_started: false,
|
||||
reasoning_started: false,
|
||||
last_token_usage: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ExecCommandBegin {
|
||||
command: Vec<String>,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
/// Metadata captured when an `McpToolCallBegin` event is received.
|
||||
struct McpToolCallBegin {
|
||||
/// Formatted invocation string, e.g. `server.tool({"city":"sf"})`.
|
||||
invocation: String,
|
||||
/// Timestamp when the call started so we can compute duration later.
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
struct PatchApplyBegin {
|
||||
start_time: Instant,
|
||||
auto_approved: bool,
|
||||
}
|
||||
|
||||
// Timestamped println helper. The timestamp is styled with self.dimmed.
|
||||
#[macro_export]
|
||||
macro_rules! ts_println {
|
||||
($self:ident, $($arg:tt)*) => {{
|
||||
let now = chrono::Utc::now();
|
||||
let formatted = now.format("[%Y-%m-%dT%H:%M:%S]");
|
||||
print!("{} ", formatted.style($self.dimmed));
|
||||
println!($($arg)*);
|
||||
}};
|
||||
}
|
||||
|
||||
impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
/// Print a concise summary of the effective configuration that will be used
|
||||
/// for the session. This mirrors the information shown in the TUI welcome
|
||||
/// screen.
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
ts_println!(
|
||||
self,
|
||||
"OpenAI Codex v{} (research preview)\n--------",
|
||||
VERSION
|
||||
);
|
||||
|
||||
let entries = create_config_summary_entries(config);
|
||||
|
||||
for (key, value) in entries {
|
||||
println!("{} {}", format!("{key}:").style(self.bold), value);
|
||||
}
|
||||
|
||||
println!("--------");
|
||||
|
||||
// Echo the prompt that will be sent to the agent so it is visible in the
|
||||
// transcript/logs before any events come in. Note the prompt may have been
|
||||
// read from stdin, so it may not be visible in the terminal otherwise.
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"User instructions:".style(self.bold).style(self.cyan),
|
||||
prompt
|
||||
);
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) {
|
||||
let Event { id: _, msg } = event;
|
||||
match msg {
|
||||
EventMsg::Error(ErrorEvent { message }) => {
|
||||
let prefix = "ERROR:".style(self.red);
|
||||
ts_println!(self, "{prefix} {message}");
|
||||
}
|
||||
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
|
||||
ts_println!(self, "{}", message.style(self.dimmed));
|
||||
}
|
||||
EventMsg::TaskStarted => {
|
||||
// no-op
|
||||
}
|
||||
EventMsg::TaskComplete(_) => {
|
||||
// On completion, append a final state entry with last token count snapshot.
|
||||
if let Ok(task_id) = std::env::var("CODEX_TASK_ID") {
|
||||
if let Some(base) = codex_base_dir_for_logging() {
|
||||
let tasks_path = base.join("tasks.jsonl");
|
||||
let ts = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
let token_json = self.last_token_usage.as_ref().map(|u| serde_json::json!({
|
||||
"input_tokens": u.input_tokens,
|
||||
"cached_input_tokens": u.cached_input_tokens,
|
||||
"output_tokens": u.output_tokens,
|
||||
"reasoning_output_tokens": u.reasoning_output_tokens,
|
||||
"total_tokens": u.total_tokens,
|
||||
}));
|
||||
let mut obj = serde_json::json!({
|
||||
"task_id": task_id,
|
||||
"completion_time": ts,
|
||||
"end_time": ts,
|
||||
"state": "done",
|
||||
});
|
||||
if let Some(tj) = token_json { if let serde_json::Value::Object(ref mut map) = obj { map.insert("token_count".to_string(), tj); } }
|
||||
let _ = append_json_line(&tasks_path, &obj);
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::TokenCount(token_usage_full) => {
|
||||
self.last_token_usage = Some(token_usage_full.clone());
|
||||
ts_println!(self, "tokens used: {}", token_usage_full.total_tokens);
|
||||
if let Ok(task_id) = std::env::var("CODEX_TASK_ID") {
|
||||
if let Some(base) = codex_base_dir_for_logging() {
|
||||
let tasks_path = base.join("tasks.jsonl");
|
||||
let ts = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
let full = serde_json::json!({
|
||||
"task_id": task_id,
|
||||
"update_time": ts,
|
||||
"token_count": {
|
||||
"input_tokens": token_usage_full.input_tokens,
|
||||
"cached_input_tokens": token_usage_full.cached_input_tokens,
|
||||
"output_tokens": token_usage_full.output_tokens,
|
||||
"reasoning_output_tokens": token_usage_full.reasoning_output_tokens,
|
||||
"total_tokens": token_usage_full.total_tokens,
|
||||
}
|
||||
});
|
||||
let _ = append_json_line(&tasks_path, &full);
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
|
||||
if !self.answer_started {
|
||||
ts_println!(self, "{}\n", "codex".style(self.italic).style(self.magenta));
|
||||
self.answer_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
|
||||
if !self.show_agent_reasoning {
|
||||
return;
|
||||
}
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n",
|
||||
"thinking".style(self.italic).style(self.magenta),
|
||||
);
|
||||
self.reasoning_started = true;
|
||||
}
|
||||
print!("{delta}");
|
||||
#[allow(clippy::expect_used)]
|
||||
std::io::stdout().flush().expect("could not flush stdout");
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
// if answer_started is false, this means we haven't received any
|
||||
// delta. Thus, we need to print the message as a new answer.
|
||||
if !self.answer_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"codex".style(self.italic).style(self.magenta),
|
||||
message,
|
||||
);
|
||||
} else {
|
||||
println!();
|
||||
self.answer_started = false;
|
||||
}
|
||||
}
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
}) => {
|
||||
self.call_id_to_command.insert(
|
||||
call_id.clone(),
|
||||
ExecCommandBegin {
|
||||
command: command.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {} in {}",
|
||||
"exec".style(self.magenta),
|
||||
escape_command(&command).style(self.bold),
|
||||
cwd.to_string_lossy(),
|
||||
);
|
||||
}
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
}) => {
|
||||
let exec_command = self.call_id_to_command.remove(&call_id);
|
||||
let (duration, call) = if let Some(ExecCommandBegin {
|
||||
command,
|
||||
start_time,
|
||||
}) = exec_command
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("{}", escape_command(&command).style(self.bold)),
|
||||
)
|
||||
} else {
|
||||
("".to_string(), format!("exec('{call_id}')"))
|
||||
};
|
||||
|
||||
let output = if exit_code == 0 { stdout } else { stderr };
|
||||
let truncated_output = output
|
||||
.lines()
|
||||
.take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
match exit_code {
|
||||
0 => {
|
||||
let title = format!("{call} succeeded{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.green));
|
||||
}
|
||||
_ => {
|
||||
let title = format!("{call} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(self.red));
|
||||
}
|
||||
}
|
||||
println!("{}", truncated_output.style(self.dimmed));
|
||||
}
|
||||
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id,
|
||||
server,
|
||||
tool,
|
||||
arguments,
|
||||
}) => {
|
||||
// Build fully-qualified tool name: server.tool
|
||||
let fq_tool_name = format!("{server}.{tool}");
|
||||
|
||||
// Format arguments as compact JSON so they fit on one line.
|
||||
let args_str = arguments
|
||||
.as_ref()
|
||||
.map(|v: &serde_json::Value| {
|
||||
serde_json::to_string(v).unwrap_or_else(|_| v.to_string())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let invocation = if args_str.is_empty() {
|
||||
format!("{fq_tool_name}()")
|
||||
} else {
|
||||
format!("{fq_tool_name}({args_str})")
|
||||
};
|
||||
|
||||
self.call_id_to_tool_call.insert(
|
||||
call_id.clone(),
|
||||
McpToolCallBegin {
|
||||
invocation: invocation.clone(),
|
||||
start_time: Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"tool".style(self.magenta),
|
||||
invocation.style(self.bold),
|
||||
);
|
||||
}
|
||||
EventMsg::McpToolCallEnd(tool_call_end_event) => {
|
||||
let is_success = tool_call_end_event.is_success();
|
||||
let McpToolCallEndEvent { call_id, result } = tool_call_end_event;
|
||||
// Retrieve start time and invocation for duration calculation and labeling.
|
||||
let info = self.call_id_to_tool_call.remove(&call_id);
|
||||
|
||||
let (duration, invocation) = if let Some(McpToolCallBegin {
|
||||
invocation,
|
||||
start_time,
|
||||
..
|
||||
}) = info
|
||||
{
|
||||
(format!(" in {}", format_elapsed(start_time)), invocation)
|
||||
} else {
|
||||
(String::new(), format!("tool('{call_id}')"))
|
||||
};
|
||||
|
||||
let status_str = if is_success { "success" } else { "failed" };
|
||||
let title_style = if is_success { self.green } else { self.red };
|
||||
let title = format!("{invocation} {status_str}{duration}:");
|
||||
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
|
||||
if let Ok(res) = result {
|
||||
let val: serde_json::Value = res.into();
|
||||
let pretty =
|
||||
serde_json::to_string_pretty(&val).unwrap_or_else(|_| val.to_string());
|
||||
|
||||
for line in pretty.lines().take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved,
|
||||
changes,
|
||||
}) => {
|
||||
// Store metadata so we can calculate duration later when we
|
||||
// receive the corresponding PatchApplyEnd event.
|
||||
self.call_id_to_patch.insert(
|
||||
call_id.clone(),
|
||||
PatchApplyBegin {
|
||||
start_time: Instant::now(),
|
||||
auto_approved,
|
||||
},
|
||||
);
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} auto_approved={}:",
|
||||
"apply_patch".style(self.magenta),
|
||||
auto_approved,
|
||||
);
|
||||
|
||||
// Pretty-print the patch summary with colored diff markers so
|
||||
// it's easy to scan in the terminal output.
|
||||
for (path, change) in changes.iter() {
|
||||
match change {
|
||||
FileChange::Add { content } => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
for line in content.lines() {
|
||||
println!("{}", line.style(self.green));
|
||||
}
|
||||
}
|
||||
FileChange::Delete => {
|
||||
let header = format!(
|
||||
"{} {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy()
|
||||
);
|
||||
println!("{}", header.style(self.magenta));
|
||||
}
|
||||
FileChange::Update {
|
||||
unified_diff,
|
||||
move_path,
|
||||
} => {
|
||||
let header = if let Some(dest) = move_path {
|
||||
format!(
|
||||
"{} {} -> {}",
|
||||
format_file_change(change),
|
||||
path.to_string_lossy(),
|
||||
dest.to_string_lossy()
|
||||
)
|
||||
} else {
|
||||
format!("{} {}", format_file_change(change), path.to_string_lossy())
|
||||
};
|
||||
println!("{}", header.style(self.magenta));
|
||||
|
||||
// Colorize diff lines. We keep file header lines
|
||||
// (--- / +++) without extra coloring so they are
|
||||
// still readable.
|
||||
for diff_line in unified_diff.lines() {
|
||||
if diff_line.starts_with('+') && !diff_line.starts_with("+++") {
|
||||
println!("{}", diff_line.style(self.green));
|
||||
} else if diff_line.starts_with('-')
|
||||
&& !diff_line.starts_with("---")
|
||||
{
|
||||
println!("{}", diff_line.style(self.red));
|
||||
} else {
|
||||
println!("{diff_line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
// Compute duration and summary label similar to exec commands.
|
||||
let (duration, label) = if let Some(PatchApplyBegin {
|
||||
start_time,
|
||||
auto_approved,
|
||||
}) = patch_begin
|
||||
{
|
||||
(
|
||||
format!(" in {}", format_elapsed(start_time)),
|
||||
format!("apply_patch(auto_approved={auto_approved})"),
|
||||
)
|
||||
} else {
|
||||
(String::new(), format!("apply_patch('{call_id}')"))
|
||||
};
|
||||
|
||||
let (exit_code, output, title_style) = if success {
|
||||
(0, stdout, self.green)
|
||||
} else {
|
||||
(1, stderr, self.red)
|
||||
};
|
||||
|
||||
let title = format!("{label} exited {exit_code}{duration}:");
|
||||
ts_println!(self, "{}", title.style(title_style));
|
||||
for line in output.lines() {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
EventMsg::AgentReasoning(agent_reasoning_event) => {
|
||||
if self.show_agent_reasoning {
|
||||
if !self.reasoning_started {
|
||||
ts_println!(
|
||||
self,
|
||||
"{}\n{}",
|
||||
"codex".style(self.italic).style(self.magenta),
|
||||
agent_reasoning_event.text,
|
||||
);
|
||||
} else {
|
||||
println!();
|
||||
self.reasoning_started = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
EventMsg::SessionConfigured(session_configured_event) => {
|
||||
let SessionConfiguredEvent {
|
||||
session_id,
|
||||
model,
|
||||
history_log_id: _,
|
||||
history_entry_count: _,
|
||||
} = session_configured_event;
|
||||
|
||||
ts_println!(
|
||||
self,
|
||||
"{} {}",
|
||||
"codex session".style(self.magenta).style(self.bold),
|
||||
session_id.to_string().style(self.dimmed)
|
||||
);
|
||||
|
||||
ts_println!(self, "model: {}", model);
|
||||
println!();
|
||||
}
|
||||
EventMsg::GetHistoryEntryResponse(_) => {
|
||||
// Currently ignored in exec output.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn escape_command(command: &[String]) -> String {
|
||||
try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
|
||||
}
|
||||
|
||||
fn format_file_change(change: &FileChange) -> &'static str {
|
||||
match change {
|
||||
FileChange::Add { .. } => "A",
|
||||
FileChange::Delete => "D",
|
||||
FileChange::Update {
|
||||
move_path: Some(_), ..
|
||||
} => "R",
|
||||
FileChange::Update {
|
||||
move_path: None, ..
|
||||
} => "M",
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::event_processor::EventProcessor;
|
||||
use crate::event_processor::create_config_summary_entries;
|
||||
|
||||
/// Event processor that writes its output as JSON lines. It carries no state.
pub(crate) struct EventProcessorWithJsonOutput;

impl EventProcessorWithJsonOutput {
    /// Creates the (stateless) processor.
    pub fn new() -> Self {
        Self
    }
}
|
||||
|
||||
impl EventProcessor for EventProcessorWithJsonOutput {
|
||||
fn print_config_summary(&mut self, config: &Config, prompt: &str) {
|
||||
let entries = create_config_summary_entries(config)
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.to_string(), value))
|
||||
.collect::<HashMap<String, String>>();
|
||||
#[allow(clippy::expect_used)]
|
||||
let config_json =
|
||||
serde_json::to_string(&entries).expect("Failed to serialize config summary to JSON");
|
||||
println!("{config_json}");
|
||||
|
||||
let prompt_json = json!({
|
||||
"prompt": prompt,
|
||||
});
|
||||
println!("{prompt_json}");
|
||||
}
|
||||
|
||||
fn process_event(&mut self, event: Event) {
|
||||
match event.msg {
|
||||
EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_) => {
|
||||
// Suppress streaming events in JSON mode.
|
||||
}
|
||||
_ => {
|
||||
if let Ok(line) = serde_json::to_string(&event) {
|
||||
println!("{line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
mod cli;
|
||||
mod event_processor;
|
||||
mod event_processor_with_human_output;
|
||||
mod event_processor_with_json_output;
|
||||
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
@@ -21,15 +19,12 @@ use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use event_processor_with_human_output::EventProcessorWithHumanOutput;
|
||||
use event_processor_with_json_output::EventProcessorWithJsonOutput;
|
||||
use event_processor::EventProcessor;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use crate::event_processor::EventProcessor;
|
||||
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let Cli {
|
||||
images,
|
||||
@@ -41,7 +36,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
skip_git_repo_check,
|
||||
color,
|
||||
last_message_file,
|
||||
json: json_mode,
|
||||
sandbox_mode: sandbox_mode_cli_arg,
|
||||
prompt,
|
||||
config_overrides,
|
||||
@@ -110,7 +104,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||
model_provider: None,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions: None,
|
||||
};
|
||||
// Parse `-c` overrides.
|
||||
let cli_kv_overrides = match config_overrides.parse_overrides() {
|
||||
@@ -122,15 +115,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
|
||||
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
|
||||
Box::new(EventProcessorWithJsonOutput::new())
|
||||
} else {
|
||||
Box::new(EventProcessorWithHumanOutput::create_with_ansi(
|
||||
stdout_with_ansi,
|
||||
&config,
|
||||
))
|
||||
};
|
||||
|
||||
let mut event_processor = EventProcessor::create_with_ansi(stdout_with_ansi, &config);
|
||||
// Print the effective configuration and prompt so users can see what Codex
|
||||
// is using.
|
||||
event_processor.print_config_summary(&config, &prompt);
|
||||
@@ -154,7 +139,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
.with_writer(std::io::stderr)
|
||||
.try_init();
|
||||
|
||||
let (codex_wrapper, event, ctrl_c, _session_id) = codex_wrapper::init_codex(config).await?;
|
||||
let (codex_wrapper, event, ctrl_c) = codex_wrapper::init_codex(config).await?;
|
||||
let codex = Arc::new(codex_wrapper);
|
||||
info!("Codex initialized with event: {event:?}");
|
||||
|
||||
@@ -237,13 +222,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
}
|
||||
}
|
||||
|
||||
// If running in concurrent auto-merge mode, attempt to commit and merge original branch.
|
||||
if std::env::var("CODEX_CONCURRENT_AUTOMERGE").ok().as_deref() == Some("1") {
|
||||
if let Err(e) = auto_commit_and_fast_forward_original_branch() {
|
||||
eprintln!("[codex-concurrent] Auto-merge skipped: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -268,88 +246,3 @@ fn handle_last_message(
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Auto-commit changes in the concurrent worktree branch and integrate them back into the original branch.
|
||||
/// Strategy:
|
||||
/// 1. Commit any pending changes on the concurrent branch.
|
||||
/// 2. Checkout the original branch in the original root and perform a --no-ff merge.
|
||||
/// Safety: Only performs merge operations if repository state allows; on conflicts it aborts and reports.
|
||||
fn auto_commit_and_fast_forward_original_branch() -> anyhow::Result<()> {
|
||||
use std::process::Command;
|
||||
let concurrent_branch = std::env::var("CODEX_CONCURRENT_BRANCH").ok().ok_or_else(|| anyhow::anyhow!("missing concurrent branch env"))?;
|
||||
let original_branch = std::env::var("CODEX_ORIGINAL_BRANCH").ok().ok_or_else(|| anyhow::anyhow!("missing original branch env"))?;
|
||||
let original_commit = std::env::var("CODEX_ORIGINAL_COMMIT").ok().ok_or_else(|| anyhow::anyhow!("missing original commit env"))?;
|
||||
let worktree_dir_env = std::env::var("CODEX_CONCURRENT_WORKTREE").ok();
|
||||
let original_root_env = std::env::var("CODEX_ORIGINAL_ROOT").ok();
|
||||
|
||||
// Determine directory to run git commit for concurrent branch (worktree if provided, else repo root from rev-parse).
|
||||
let worktree_dir = if let Some(wt) = worktree_dir_env.clone() {
|
||||
std::path::PathBuf::from(wt)
|
||||
} else {
|
||||
let repo_root = Command::new("git").args(["rev-parse", "--show-toplevel"]).output()?;
|
||||
if !repo_root.status.success() { anyhow::bail!("not a git repo"); }
|
||||
std::path::PathBuf::from(String::from_utf8_lossy(&repo_root.stdout).trim().to_string())
|
||||
};
|
||||
|
||||
// Commit pending changes (git add ., git commit -m ...).
|
||||
let status_out = Command::new("git")
|
||||
.current_dir(&worktree_dir)
|
||||
.args(["status", "--porcelain"]).output()?;
|
||||
if !status_out.status.success() { anyhow::bail!("git status failed"); }
|
||||
if !status_out.stdout.is_empty() {
|
||||
let add_status = Command::new("git")
|
||||
.current_dir(&worktree_dir)
|
||||
.args(["add", "."]).status()?;
|
||||
if !add_status.success() { anyhow::bail!("git add failed"); }
|
||||
let commit_msg = format!("Codex concurrent run auto-commit on branch {concurrent_branch}");
|
||||
let commit_status = Command::new("git")
|
||||
.current_dir(&worktree_dir)
|
||||
.args(["commit", "-m", &commit_msg]).status()?;
|
||||
if !commit_status.success() { anyhow::bail!("git commit failed"); }
|
||||
eprintln!("[codex-concurrent] Created commit in {concurrent_branch}.");
|
||||
} else {
|
||||
eprintln!("[codex-concurrent] No changes to commit in {concurrent_branch}.");
|
||||
}
|
||||
|
||||
// Capture head of concurrent branch (for potential future use / diagnostics).
|
||||
let concurrent_head_out = Command::new("git")
|
||||
.current_dir(&worktree_dir)
|
||||
.args(["rev-parse", &concurrent_branch]).output()?;
|
||||
if !concurrent_head_out.status.success() { anyhow::bail!("failed to rev-parse concurrent branch"); }
|
||||
|
||||
// Determine where to integrate (original root if known, else worktree).
|
||||
let integration_dir = if let Some(root) = original_root_env.clone() { std::path::PathBuf::from(root) } else { worktree_dir.clone() };
|
||||
|
||||
// Checkout original branch.
|
||||
let co_status = Command::new("git")
|
||||
.current_dir(&integration_dir)
|
||||
.args(["checkout", &original_branch])
|
||||
.status()?;
|
||||
if !co_status.success() { anyhow::bail!("git checkout {original_branch} failed in original root"); }
|
||||
|
||||
// Check if concurrent branch already merged (ancestor test).
|
||||
let ancestor_status = Command::new("git")
|
||||
.current_dir(&integration_dir)
|
||||
.args(["merge-base", "--is-ancestor", &concurrent_branch, &original_branch])
|
||||
.status();
|
||||
if let Ok(code) = ancestor_status {
|
||||
if code.success() {
|
||||
eprintln!("[codex-concurrent] {concurrent_branch} already merged into {original_branch}; skipping.");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Perform a --no-ff merge.
|
||||
let merge_msg = format!("Merge concurrent Codex branch {concurrent_branch} (base {original_commit})");
|
||||
let merge_status = Command::new("git")
|
||||
.current_dir(&integration_dir)
|
||||
.args(["merge", "--no-ff", &concurrent_branch, "-m", &merge_msg])
|
||||
.status()?;
|
||||
if !merge_status.success() {
|
||||
let _ = Command::new("git").current_dir(&integration_dir).args(["merge", "--abort"]).status();
|
||||
anyhow::bail!("git merge --no-ff failed (conflicts?)");
|
||||
}
|
||||
eprintln!("[codex-concurrent] Merged {concurrent_branch} into {original_branch} in original root: {}", integration_dir.display());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -23,10 +23,3 @@ file-search *args:
|
||||
# format code
|
||||
fmt:
|
||||
cargo fmt -- --config imports_granularity=Item
|
||||
|
||||
fix:
|
||||
cargo clippy --fix --all-features --tests --allow-dirty
|
||||
|
||||
install:
|
||||
rustup show active-toolchain
|
||||
cargo fetch
|
||||
|
||||
@@ -57,12 +57,10 @@ async fn main() -> Result<()> {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-mcp-client".to_owned(),
|
||||
version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
title: Some("Codex".to_string()),
|
||||
},
|
||||
protocol_version: MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
@@ -22,7 +22,6 @@ mcp-types = { path = "../mcp-types" }
|
||||
schemars = "0.8.22"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
shlex = "1.3.0"
|
||||
toml = "0.9"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
||||
@@ -33,11 +32,6 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
pretty_assertions = "1.4.1"
|
||||
tempfile = "3"
|
||||
tokio-test = "0.4"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -7,16 +7,15 @@ use mcp_types::ToolInputSchema;
|
||||
use schemars::JsonSchema;
|
||||
use schemars::r#gen::SchemaSettings;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::json_to_toml::json_to_toml;
|
||||
|
||||
/// Client-supplied configuration for a `codex` tool-call.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct CodexToolCallParam {
|
||||
pub(crate) struct CodexToolCallParam {
|
||||
/// The *initial user prompt* to start the Codex conversation.
|
||||
pub prompt: String,
|
||||
|
||||
@@ -46,17 +45,13 @@ pub struct CodexToolCallParam {
|
||||
/// CODEX_HOME/config.toml.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub config: Option<HashMap<String, serde_json::Value>>,
|
||||
|
||||
/// The set of instructions to use instead of the default ones.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub base_instructions: Option<String>,
|
||||
}
|
||||
|
||||
/// Custom enum mirroring [`AskForApproval`], but has an extra dependency on
|
||||
/// [`JsonSchema`].
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum CodexToolCallApprovalPolicy {
|
||||
pub(crate) enum CodexToolCallApprovalPolicy {
|
||||
Untrusted,
|
||||
OnFailure,
|
||||
Never,
|
||||
@@ -74,9 +69,9 @@ impl From<CodexToolCallApprovalPolicy> for AskForApproval {
|
||||
|
||||
/// Custom enum mirroring [`SandboxMode`] from config_types.rs, but with
|
||||
/// `JsonSchema` support.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum CodexToolCallSandboxMode {
|
||||
pub(crate) enum CodexToolCallSandboxMode {
|
||||
ReadOnly,
|
||||
WorkspaceWrite,
|
||||
DangerFullAccess,
|
||||
@@ -113,10 +108,7 @@ pub(crate) fn create_tool_for_codex_tool_call_param() -> Tool {
|
||||
|
||||
Tool {
|
||||
name: "codex".to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
input_schema: tool_input_schema,
|
||||
// TODO(mbolin): This should be defined.
|
||||
output_schema: None,
|
||||
description: Some(
|
||||
"Run a Codex session. Accepts configuration parameters matching the Codex Config struct.".to_string(),
|
||||
),
|
||||
@@ -139,7 +131,6 @@ impl CodexToolCallParam {
|
||||
approval_policy,
|
||||
sandbox,
|
||||
config: cli_overrides,
|
||||
base_instructions,
|
||||
} = self;
|
||||
|
||||
// Build the `ConfigOverrides` recognised by codex-core.
|
||||
@@ -151,7 +142,6 @@ impl CodexToolCallParam {
|
||||
sandbox_mode: sandbox.map(Into::into),
|
||||
model_provider: None,
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions,
|
||||
};
|
||||
|
||||
let cli_overrides = cli_overrides
|
||||
@@ -166,47 +156,6 @@ impl CodexToolCallParam {
|
||||
}
|
||||
}
|
||||
|
||||
// Arguments of the `codex-reply` tool-call; the tool's input schema is
// generated from this struct, with field names serialized in camelCase
// (`sessionId`, `prompt`).
//
// NOTE: the `///` comments on the fields show up verbatim as `description`
// strings in the schema asserted by `verify_codex_tool_reply_json_schema`, so
// they are part of the tool's output; change them together with that test.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub(crate) struct CodexToolCallReplyParam {
    /// The *session id* for this conversation.
    pub session_id: String,

    /// The *next user prompt* to continue the Codex conversation.
    pub prompt: String,
}
|
||||
|
||||
/// Builds a `Tool` definition for the `codex-reply` tool-call.
|
||||
pub(crate) fn create_tool_for_codex_tool_call_reply_param() -> Tool {
|
||||
let schema = SchemaSettings::draft2019_09()
|
||||
.with(|s| {
|
||||
s.inline_subschemas = true;
|
||||
s.option_add_null_type = false;
|
||||
})
|
||||
.into_generator()
|
||||
.into_root_schema_for::<CodexToolCallReplyParam>();
|
||||
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema_value =
|
||||
serde_json::to_value(&schema).expect("Codex reply tool schema should serialise to JSON");
|
||||
|
||||
let tool_input_schema =
|
||||
serde_json::from_value::<ToolInputSchema>(schema_value).unwrap_or_else(|e| {
|
||||
panic!("failed to create Tool from schema: {e}");
|
||||
});
|
||||
|
||||
Tool {
|
||||
name: "codex-reply".to_string(),
|
||||
title: Some("Codex Reply".to_string()),
|
||||
input_schema: tool_input_schema,
|
||||
output_schema: None,
|
||||
description: Some(
|
||||
"Continue a Codex session by providing the session id and prompt.".to_string(),
|
||||
),
|
||||
annotations: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -230,7 +179,6 @@ mod tests {
|
||||
let tool_json = serde_json::to_value(&tool).expect("tool serializes");
|
||||
let expected_tool_json = serde_json::json!({
|
||||
"name": "codex",
|
||||
"title": "Codex",
|
||||
"description": "Run a Codex session. Accepts configuration parameters matching the Codex Config struct.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
@@ -274,10 +222,6 @@ mod tests {
|
||||
"description": "The *initial user prompt* to start the Codex conversation.",
|
||||
"type": "string"
|
||||
},
|
||||
"base-instructions": {
|
||||
"description": "The set of instructions to use instead of the default ones.",
|
||||
"type": "string"
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"prompt"
|
||||
@@ -286,34 +230,4 @@ mod tests {
|
||||
});
|
||||
assert_eq!(expected_tool_json, tool_json);
|
||||
}
|
||||
|
||||
/// Pins the exact JSON form of the `codex-reply` tool definition, including
/// the camelCase property names and the per-field descriptions.
#[test]
fn verify_codex_tool_reply_json_schema() {
    let expected_tool_json = serde_json::json!({
        "name": "codex-reply",
        "title": "Codex Reply",
        "description": "Continue a Codex session by providing the session id and prompt.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "sessionId": {
                    "description": "The *session id* for this conversation.",
                    "type": "string"
                },
                "prompt": {
                    "description": "The *next user prompt* to continue the Codex conversation.",
                    "type": "string"
                },
            },
            "required": [
                "prompt",
                "sessionId",
            ],
        },
    });

    let reply_tool = create_tool_for_codex_tool_call_reply_param();
    #[expect(clippy::expect_used)]
    let tool_json = serde_json::to_value(&reply_tool).expect("tool serializes");

    assert_eq!(expected_tool_json, tool_json);
}
|
||||
}
|
||||
|
||||
@@ -2,32 +2,33 @@
|
||||
//! Tokio task. Separated from `message_processor.rs` to keep that file small
|
||||
//! and to make future feature-growth easier to manage.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::CallToolResultContent;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
|
||||
pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602;
|
||||
/// Convert a Codex [`Event`] to an MCP notification.
|
||||
fn codex_event_to_notification(event: &Event) -> JSONRPCMessage {
|
||||
#[expect(clippy::expect_used)]
|
||||
JSONRPCMessage::Notification(mcp_types::JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: "codex/event".into(),
|
||||
params: Some(serde_json::to_value(event).expect("Event must serialize")),
|
||||
})
|
||||
}
|
||||
|
||||
/// Run a complete Codex session and stream events back to the client.
|
||||
///
|
||||
@@ -37,34 +38,34 @@ pub async fn run_codex_tool_session(
|
||||
id: RequestId,
|
||||
initial_prompt: String,
|
||||
config: CodexConfig,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
outgoing: Sender<JSONRPCMessage>,
|
||||
) {
|
||||
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
|
||||
let (codex, first_event, _ctrl_c) = match init_codex(config).await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Failed to start Codex session: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing.send_response(id.clone(), result.into()).await;
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let codex = Arc::new(codex);
|
||||
|
||||
// update the session map so we can retrieve the session in a reply, and then drop it, since
|
||||
// we no longer need it for this function
|
||||
session_map.lock().await.insert(session_id, codex.clone());
|
||||
drop(session_map);
|
||||
|
||||
// Send initial SessionConfigured event.
|
||||
outgoing.send_event_as_notification(&first_event).await;
|
||||
let _ = outgoing
|
||||
.send(codex_event_to_notification(&first_event))
|
||||
.await;
|
||||
|
||||
// Use the original MCP request ID as the `sub_id` for the Codex submission so that
|
||||
// any events emitted for this tool-call can be correlated with the
|
||||
@@ -75,7 +76,7 @@ pub async fn run_codex_tool_session(
|
||||
};
|
||||
|
||||
let submission = Submission {
|
||||
id: sub_id.clone(),
|
||||
id: sub_id,
|
||||
op: Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: initial_prompt.clone(),
|
||||
@@ -87,96 +88,83 @@ pub async fn run_codex_tool_session(
|
||||
tracing::error!("Failed to submit initial prompt: {e}");
|
||||
}
|
||||
|
||||
run_codex_tool_session_inner(codex, outgoing, id).await;
|
||||
}
|
||||
|
||||
pub async fn run_codex_tool_session_reply(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
prompt: String,
|
||||
) {
|
||||
if let Err(e) = codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text { text: prompt }],
|
||||
})
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to submit user input: {e}");
|
||||
}
|
||||
|
||||
run_codex_tool_session_inner(codex, outgoing, request_id).await;
|
||||
}
|
||||
|
||||
async fn run_codex_tool_session_inner(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing.send_event_as_notification(&event).await;
|
||||
let _ = outgoing.send(codex_event_to_notification(&event)).await;
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
match &event.msg {
|
||||
EventMsg::AgentMessage(AgentMessageEvent { message }) => {
|
||||
last_agent_message = Some(message.clone());
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
let text = match last_agent_message {
|
||||
Some(msg) => msg.clone(),
|
||||
None => "".to_string(),
|
||||
};
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text,
|
||||
text: "EXEC_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing
|
||||
.send_response(request_id.clone(), result.into())
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "PATCH_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
};
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: _,
|
||||
}) => {
|
||||
let result = if let Some(msg) = last_agent_message {
|
||||
CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: msg,
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
}
|
||||
} else {
|
||||
CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: String::new(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
}
|
||||
};
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
@@ -189,9 +177,6 @@ async fn run_codex_tool_session_inner(
|
||||
EventMsg::AgentReasoningDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
@@ -215,18 +200,19 @@ async fn run_codex_tool_session_inner(
|
||||
}
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Codex runtime error: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
// TODO(mbolin): Could present the error in a more
|
||||
// structured way.
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing
|
||||
.send_response(request_id.clone(), result.into())
|
||||
let _ = outgoing
|
||||
.send(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: id.clone(),
|
||||
result: result.into(),
|
||||
}))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
|
||||
/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the
|
||||
/// `params` field of an [`ElicitRequest`].
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecApprovalElicitRequestParams {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
// Human-readable prompt describing the approval being requested.
pub message: String,
|
||||
|
||||
// Serialized under the camelCase key `requestedSchema`.
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
|
||||
// These are additional fields the client can use to
|
||||
// correlate the request with the codex tool call.
|
||||
// Tag naming the kind of elicitation.
pub codex_elicitation: String,
|
||||
// Identifier of the MCP tool call this approval belongs to.
pub codex_mcp_tool_call_id: String,
|
||||
// Identifier of the Codex event that triggered the approval request.
pub codex_event_id: String,
|
||||
// The command (argv-style) Codex is asking permission to run.
pub codex_command: Vec<String>,
|
||||
// Working directory in which the command would run.
pub codex_cwd: PathBuf,
|
||||
}
|
||||
|
||||
// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See:
|
||||
// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636
|
||||
// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages
|
||||
// It should have "action" and "content" fields.
|
||||
/// Reply payload expected from the client for an exec-approval elicitation.
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ExecApprovalResponse {
|
||||
// The reviewer's verdict on the requested command.
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_exec_approval_request(
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
outgoing: Arc<crate::outgoing_message::OutgoingMessageSender>,
|
||||
codex: Arc<Codex>,
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
let escaped_command =
|
||||
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "));
|
||||
let message = format!(
|
||||
"Allow Codex to run `{escaped_command}` in `{cwd}`?",
|
||||
cwd = cwd.to_string_lossy()
|
||||
);
|
||||
|
||||
let params = ExecApprovalElicitRequestParams {
|
||||
message,
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_command: command,
|
||||
codex_cwd: cwd,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!("Failed to serialize ExecApprovalElicitRequestParams: {err}");
|
||||
error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we don't block the main agent loop.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event_id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_exec_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to deserialize `value` and then make the appropriate call to `codex`.
|
||||
let response = serde_json::from_value::<ExecApprovalResponse>(value).unwrap_or_else(|err| {
|
||||
error!("failed to deserialize ExecApprovalResponse: {err}");
|
||||
// If we cannot deserialize the response, we deny the request to be
|
||||
// conservative.
|
||||
ExecApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit ExecApproval: {err}");
|
||||
}
|
||||
}
|
||||
@@ -16,21 +16,10 @@ use tracing::info;
|
||||
|
||||
mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod exec_approval;
|
||||
mod json_to_toml;
|
||||
mod message_processor;
|
||||
mod outgoing_message;
|
||||
mod patch_approval;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
pub use crate::codex_tool_config::CodexToolCallParam;
|
||||
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
|
||||
pub use crate::exec_approval::ExecApprovalResponse;
|
||||
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
|
||||
pub use crate::patch_approval::PatchApprovalResponse;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage – 128 messages should be
|
||||
@@ -46,7 +35,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
|
||||
// Set up channels.
|
||||
let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
|
||||
// Task: read from stdin, push to `incoming_tx`.
|
||||
let stdin_reader_handle = tokio::spawn({
|
||||
@@ -74,15 +63,16 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
|
||||
// Task: process incoming messages.
|
||||
let processor_handle = tokio::spawn({
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe);
|
||||
let mut processor = MessageProcessor::new(outgoing_tx.clone(), codex_linux_sandbox_exe);
|
||||
async move {
|
||||
while let Some(msg) = incoming_rx.recv().await {
|
||||
match msg {
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r),
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r),
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n),
|
||||
JSONRPCMessage::BatchRequest(b) => processor.process_batch_request(b),
|
||||
JSONRPCMessage::Error(e) => processor.process_error(e),
|
||||
JSONRPCMessage::BatchResponse(b) => processor.process_batch_response(b),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,8 +83,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
||||
// Task: write outgoing messages to stdout.
|
||||
let stdout_writer_handle = tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(outgoing_message) = outgoing_rx.recv().await {
|
||||
let msg: JSONRPCMessage = outgoing_message.into();
|
||||
while let Some(msg) = outgoing_rx.recv().await {
|
||||
match serde_json::to_string(&msg) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = stdout.write_all(json.as_bytes()).await {
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::codex_tool_config::CodexToolCallParam;
|
||||
use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
|
||||
use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::CallToolResultContent;
|
||||
use mcp_types::ClientRequest;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCBatchRequest;
|
||||
use mcp_types::JSONRPCBatchResponse;
|
||||
use mcp_types::JSONRPCError;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
@@ -26,33 +24,30 @@ use mcp_types::ServerCapabilitiesTools;
|
||||
use mcp_types::ServerNotification;
|
||||
use mcp_types::TextContent;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) struct MessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
outgoing: mpsc::Sender<JSONRPCMessage>,
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
/// Create a new `MessageProcessor`, retaining a handle to the outgoing
|
||||
/// `Sender` so handlers can enqueue messages to be written to stdout.
|
||||
pub(crate) fn new(
|
||||
outgoing: OutgoingMessageSender,
|
||||
outgoing: mpsc::Sender<JSONRPCMessage>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outgoing: Arc::new(outgoing),
|
||||
outgoing,
|
||||
initialized: false,
|
||||
codex_linux_sandbox_exe,
|
||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
pub(crate) fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
// Hold on to the ID so we can respond.
|
||||
let request_id = request.id.clone();
|
||||
|
||||
@@ -67,10 +62,10 @@ impl MessageProcessor {
|
||||
// Dispatch to a dedicated handler for each request type.
|
||||
match client_request {
|
||||
ClientRequest::InitializeRequest(params) => {
|
||||
self.handle_initialize(request_id, params).await;
|
||||
self.handle_initialize(request_id, params);
|
||||
}
|
||||
ClientRequest::PingRequest(params) => {
|
||||
self.handle_ping(request_id, params).await;
|
||||
self.handle_ping(request_id, params);
|
||||
}
|
||||
ClientRequest::ListResourcesRequest(params) => {
|
||||
self.handle_list_resources(params);
|
||||
@@ -94,10 +89,10 @@ impl MessageProcessor {
|
||||
self.handle_get_prompt(params);
|
||||
}
|
||||
ClientRequest::ListToolsRequest(params) => {
|
||||
self.handle_list_tools(request_id, params).await;
|
||||
self.handle_list_tools(request_id, params);
|
||||
}
|
||||
ClientRequest::CallToolRequest(params) => {
|
||||
self.handle_call_tool(request_id, params).await;
|
||||
self.handle_call_tool(request_id, params);
|
||||
}
|
||||
ClientRequest::SetLevelRequest(params) => {
|
||||
self.handle_set_level(params);
|
||||
@@ -109,10 +104,8 @@ impl MessageProcessor {
|
||||
}
|
||||
|
||||
/// Handle a standalone JSON-RPC response originating from the peer.
|
||||
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
pub(crate) fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
tracing::info!("<- response: {:?}", response);
|
||||
let JSONRPCResponse { id, result, .. } = response;
|
||||
self.outgoing.notify_client_response(id, result).await
|
||||
}
|
||||
|
||||
/// Handle a fire-and-forget JSON-RPC notification.
|
||||
@@ -152,12 +145,42 @@ impl MessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a batch of requests and/or notifications.
|
||||
pub(crate) fn process_batch_request(&mut self, batch: JSONRPCBatchRequest) {
|
||||
tracing::info!("<- batch request containing {} item(s)", batch.len());
|
||||
for item in batch {
|
||||
match item {
|
||||
mcp_types::JSONRPCBatchRequestItem::JSONRPCRequest(req) => {
|
||||
self.process_request(req);
|
||||
}
|
||||
mcp_types::JSONRPCBatchRequestItem::JSONRPCNotification(note) => {
|
||||
self.process_notification(note);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an error object received from the peer.
|
||||
pub(crate) fn process_error(&mut self, err: JSONRPCError) {
|
||||
tracing::error!("<- error: {:?}", err);
|
||||
}
|
||||
|
||||
async fn handle_initialize(
|
||||
/// Handle a batch of responses/errors.
|
||||
pub(crate) fn process_batch_response(&mut self, batch: JSONRPCBatchResponse) {
|
||||
tracing::info!("<- batch response containing {} item(s)", batch.len());
|
||||
for item in batch {
|
||||
match item {
|
||||
mcp_types::JSONRPCBatchResponseItem::JSONRPCResponse(resp) => {
|
||||
self.process_response(resp);
|
||||
}
|
||||
mcp_types::JSONRPCBatchResponseItem::JSONRPCError(err) => {
|
||||
self.process_error(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_initialize(
|
||||
&mut self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
|
||||
@@ -166,12 +189,19 @@ impl MessageProcessor {
|
||||
|
||||
if self.initialized {
|
||||
// Already initialised: send JSON-RPC error response.
|
||||
let error = JSONRPCErrorError {
|
||||
code: -32600, // Invalid Request
|
||||
message: "initialize called more than once".to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(id, error).await;
|
||||
let error_msg = JSONRPCMessage::Error(JSONRPCError {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
error: JSONRPCErrorError {
|
||||
code: -32600, // Invalid Request
|
||||
message: "initialize called more than once".to_string(),
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
|
||||
if let Err(e) = self.outgoing.try_send(error_msg) {
|
||||
tracing::error!("Failed to send initialization error: {e}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -193,34 +223,38 @@ impl MessageProcessor {
|
||||
protocol_version: params.protocol_version.clone(),
|
||||
server_info: mcp_types::Implementation {
|
||||
name: "codex-mcp-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
title: Some("Codex".to_string()),
|
||||
version: mcp_types::MCP_SCHEMA_VERSION.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
self.send_response::<mcp_types::InitializeRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::InitializeRequest>(id, result);
|
||||
}
|
||||
|
||||
async fn send_response<T>(&self, id: RequestId, result: T::Result)
|
||||
fn send_response<T>(&self, id: RequestId, result: T::Result)
|
||||
where
|
||||
T: ModelContextProtocolRequest,
|
||||
{
|
||||
// `result` implements `Serialize`, so this conversion should never fail
|
||||
#[expect(clippy::unwrap_used)]
|
||||
let result = serde_json::to_value(result).unwrap();
|
||||
self.outgoing.send_response(id, result).await;
|
||||
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result: serde_json::to_value(result).unwrap(),
|
||||
});
|
||||
|
||||
if let Err(e) = self.outgoing.try_send(response) {
|
||||
tracing::error!("Failed to send response: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ping(
|
||||
fn handle_ping(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::PingRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::info!("ping -> params: {:?}", params);
|
||||
let result = json!({});
|
||||
self.send_response::<mcp_types::PingRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::PingRequest>(id, result);
|
||||
}
|
||||
|
||||
fn handle_list_resources(
|
||||
@@ -273,25 +307,21 @@ impl MessageProcessor {
|
||||
tracing::info!("prompts/get -> params: {:?}", params);
|
||||
}
|
||||
|
||||
async fn handle_list_tools(
|
||||
fn handle_list_tools(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
) {
|
||||
tracing::trace!("tools/list -> {params:?}");
|
||||
let result = ListToolsResult {
|
||||
tools: vec![
|
||||
create_tool_for_codex_tool_call_param(),
|
||||
create_tool_for_codex_tool_call_reply_param(),
|
||||
],
|
||||
tools: vec![create_tool_for_codex_tool_call_param()],
|
||||
next_cursor: None,
|
||||
};
|
||||
|
||||
self.send_response::<mcp_types::ListToolsRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::ListToolsRequest>(id, result);
|
||||
}
|
||||
|
||||
async fn handle_call_tool(
|
||||
fn handle_call_tool(
|
||||
&self,
|
||||
id: RequestId,
|
||||
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
|
||||
@@ -299,36 +329,28 @@ impl MessageProcessor {
|
||||
tracing::info!("tools/call -> params: {:?}", params);
|
||||
let CallToolRequestParams { name, arguments } = params;
|
||||
|
||||
match name.as_str() {
|
||||
"codex" => self.handle_tool_call_codex(id, arguments).await,
|
||||
"codex-reply" => {
|
||||
self.handle_tool_call_codex_session_reply(id, arguments)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Unknown tool '{name}'"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
}
|
||||
// We only support the "codex" tool for now.
|
||||
if name != "codex" {
|
||||
// Tool not found – return error result so the LLM can react.
|
||||
let result = CallToolResult {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: format!("Unknown tool '{name}'"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option<serde_json::Value>) {
|
||||
let (initial_prompt, config): (String, CodexConfig) = match arguments {
|
||||
Some(json_val) => match serde_json::from_value::<CodexToolCallParam>(json_val) {
|
||||
Ok(tool_cfg) => match tool_cfg.into_config(self.codex_linux_sandbox_exe.clone()) {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!(
|
||||
"Failed to load Codex configuration from overrides: {e}"
|
||||
@@ -336,31 +358,27 @@ impl MessageProcessor {
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse configuration for Codex tool: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
content: vec![CallToolResultContent::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text:
|
||||
"Missing arguments for codex tool-call; the `prompt` field is required."
|
||||
@@ -368,135 +386,21 @@ impl MessageProcessor {
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result)
|
||||
.await;
|
||||
self.send_response::<mcp_types::CallToolRequest>(id, result);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Clone outgoing and session map to move into async task.
|
||||
// Clone outgoing sender to move into async task.
|
||||
let outgoing = self.outgoing.clone();
|
||||
let session_map = self.session_map.clone();
|
||||
|
||||
// Spawn an async task to handle the Codex session so that we do not
|
||||
// block the synchronous message-processing loop.
|
||||
task::spawn(async move {
|
||||
// Run the Codex session and stream events back to the client.
|
||||
crate::codex_tool_runner::run_codex_tool_session(
|
||||
id,
|
||||
initial_prompt,
|
||||
config,
|
||||
outgoing,
|
||||
session_map,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
async fn handle_tool_call_codex_session_reply(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
arguments: Option<serde_json::Value>,
|
||||
) {
|
||||
tracing::info!("tools/call -> params: {:?}", arguments);
|
||||
|
||||
// parse arguments
|
||||
let CodexToolCallReplyParam { session_id, prompt } = match arguments {
|
||||
Some(json_val) => match serde_json::from_value::<CodexToolCallReplyParam>(json_val) {
|
||||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse Codex tool call reply parameters: {e}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse configuration for Codex tool: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
tracing::error!(
|
||||
"Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required."
|
||||
);
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required.".to_owned(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let session_id = match Uuid::parse_str(&session_id) {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse session_id: {e}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Failed to parse session_id: {e}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
self.send_response::<mcp_types::CallToolRequest>(request_id, result)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// load codex from session map
|
||||
let session_map_mutex = Arc::clone(&self.session_map);
|
||||
|
||||
// Clone outgoing and session map to move into async task.
|
||||
let outgoing = self.outgoing.clone();
|
||||
|
||||
// Spawn an async task to handle the Codex session so that we do not
|
||||
// block the synchronous message-processing loop.
|
||||
task::spawn(async move {
|
||||
let session_map = session_map_mutex.lock().await;
|
||||
let codex = match session_map.get(&session_id) {
|
||||
Some(codex) => codex,
|
||||
None => {
|
||||
tracing::warn!("Session not found for session_id: {session_id}");
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_owned(),
|
||||
text: format!("Session not found for session_id: {session_id}"),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: Some(true),
|
||||
structured_content: None,
|
||||
};
|
||||
// unwrap_or_default is fine here because we know the result is valid JSON
|
||||
outgoing
|
||||
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
crate::codex_tool_runner::run_codex_tool_session_reply(
|
||||
codex.clone(),
|
||||
outgoing,
|
||||
request_id,
|
||||
prompt.clone(),
|
||||
)
|
||||
.await;
|
||||
crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_core::protocol::Event;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCError;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::Result;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::warn;
|
||||
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingMessage>,
|
||||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
|
||||
Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
sender,
|
||||
request_id_to_callback: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> oneshot::Receiver<Result> {
|
||||
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let outgoing_message_id = id.clone();
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
{
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.insert(id, tx_approve);
|
||||
}
|
||||
|
||||
let outgoing_message = OutgoingMessage::Request(OutgoingRequest {
|
||||
id: outgoing_message_id,
|
||||
method: method.to_string(),
|
||||
params,
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
rx_approve
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) {
|
||||
let entry = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove_entry(&id)
|
||||
};
|
||||
|
||||
match entry {
|
||||
Some((id, sender)) => {
|
||||
if let Err(err) = sender.send(result) {
|
||||
warn!("could not notify callback for {id:?} due to: {err:?}");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("could not find callback for {id:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response(&self, id: RequestId, result: Result) {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_event_as_notification(&self, event: &Event) {
|
||||
#[expect(clippy::expect_used)]
|
||||
let params = Some(serde_json::to_value(event).expect("Event must serialize"));
|
||||
let outgoing_message = OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: "codex/event".to_string(),
|
||||
params,
|
||||
});
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
let _ = self.sender.send(outgoing_message).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
pub(crate) enum OutgoingMessage {
|
||||
Request(OutgoingRequest),
|
||||
Notification(OutgoingNotification),
|
||||
Response(OutgoingResponse),
|
||||
Error(OutgoingError),
|
||||
}
|
||||
|
||||
impl From<OutgoingMessage> for JSONRPCMessage {
|
||||
fn from(val: OutgoingMessage) -> Self {
|
||||
use OutgoingMessage::*;
|
||||
match val {
|
||||
Request(OutgoingRequest { id, method, params }) => {
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Notification(OutgoingNotification { method, params }) => {
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
Response(OutgoingResponse { id, result }) => {
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result,
|
||||
})
|
||||
}
|
||||
Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
error,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingRequest {
|
||||
pub id: RequestId,
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingNotification {
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingResponse {
|
||||
pub id: RequestId,
|
||||
pub result: Result,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingError {
|
||||
pub error: JSONRPCErrorError,
|
||||
pub id: RequestId,
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PatchApprovalElicitRequestParams {
|
||||
pub message: String,
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_grant_root: Option<PathBuf>,
|
||||
pub codex_changes: HashMap<PathBuf, FileChange>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct PatchApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_patch_approval_request(
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex: Arc<Codex>,
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
let mut message_lines = Vec::new();
|
||||
if let Some(r) = &reason {
|
||||
message_lines.push(r.clone());
|
||||
}
|
||||
message_lines.push("Allow Codex to apply proposed code changes?".to_string());
|
||||
|
||||
let params = PatchApprovalElicitRequestParams {
|
||||
message: message_lines.join("\n"),
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!("Failed to serialize PatchApprovalElicitRequestParams: {err}");
|
||||
error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we don't block the main agent loop.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event_id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_patch_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn on_patch_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
if let Err(submit_err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id.clone(),
|
||||
decision: ReviewDecision::Denied,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit denied PatchApproval after request failure: {submit_err}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let response = serde_json::from_value::<PatchApprovalResponse>(value).unwrap_or_else(|err| {
|
||||
error!("failed to deserialize PatchApprovalResponse: {err}");
|
||||
PatchApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit PatchApproval: {err}");
|
||||
}
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
mod common;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::ExecApprovalElicitRequestParams;
|
||||
use codex_mcp_server::ExecApprovalResponse;
|
||||
use codex_mcp_server::PatchApprovalElicitRequestParams;
|
||||
use codex_mcp_server::PatchApprovalResponse;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::MockServer;
|
||||
|
||||
use crate::common::McpProcess;
|
||||
use crate::common::create_apply_patch_sse_response;
|
||||
use crate::common::create_final_assistant_message_sse_response;
|
||||
use crate::common::create_mock_chat_completions_server;
|
||||
use crate::common::create_shell_sse_response;
|
||||
|
||||
/// Upper bound (10 seconds) applied to each wait on the MCP process in these tests.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
/// Test that a shell command that is not on the "trusted" list triggers an
|
||||
/// elicitation request to the MCP and that sending the approval runs the
|
||||
/// command, as expected.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shell_command_approval_triggers_elicitation() {
|
||||
if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Apparently `#[tokio::test]` must return `()`, so we create a helper
|
||||
// function that returns `Result` so we can use `?` in favor of `unwrap`.
|
||||
if let Err(err) = shell_command_approval_triggers_elicitation().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
// We use `git init` because it will not be on the "trusted" list.
|
||||
let shell_command = vec!["git".to_string(), "init".to_string()];
|
||||
let workdir_for_shell_function_call = TempDir::new()?;
|
||||
|
||||
let McpHandle {
|
||||
process: mut mcp_process,
|
||||
server: _server,
|
||||
dir: _dir,
|
||||
} = create_mcp_process(vec![
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
Some(5_000),
|
||||
"call1234",
|
||||
)?,
|
||||
create_final_assistant_message_sse_response("Enjoy your new git repo!")?,
|
||||
])
|
||||
.await?;
|
||||
|
||||
// Send a "codex" tool request, which should hit the completions endpoint.
|
||||
// In turn, it should reply with a tool call, which the MCP should forward
|
||||
// as an elicitation.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "run `git init`".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let elicitation_request = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_request_message(),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// This is the first request from the server, so the id should be 0 given
|
||||
// how things are currently implemented.
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
let expected_elicitation_request = create_expected_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
shell_command.clone(),
|
||||
workdir_for_shell_function_call.path(),
|
||||
codex_request_id.to_string(),
|
||||
// Internal Codex id: empirically it is 1, but this is
|
||||
// admittedly an internal detail that could change.
|
||||
"1".to_string(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
// Accept the `git init` request by responding to the elicitation.
|
||||
mcp_process
|
||||
.send_response(
|
||||
elicitation_request_id,
|
||||
serde_json::to_value(ExecApprovalResponse {
|
||||
decision: ReviewDecision::Approved,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Verify the original `codex` tool call completes and that `git init` ran
|
||||
// successfully.
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Enjoy your new git repo!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
assert!(
|
||||
workdir_for_shell_function_call.path().join(".git").is_dir(),
|
||||
".git folder should have been created"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_expected_elicitation_request(
|
||||
elicitation_request_id: RequestId,
|
||||
command: Vec<String>,
|
||||
workdir: &Path,
|
||||
codex_mcp_tool_call_id: String,
|
||||
codex_event_id: String,
|
||||
) -> anyhow::Result<JSONRPCRequest> {
|
||||
let expected_message = format!(
|
||||
"Allow Codex to run `{}` in `{}`?",
|
||||
shlex::try_join(command.iter().map(|s| s.as_ref()))?,
|
||||
workdir.to_string_lossy()
|
||||
);
|
||||
Ok(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: elicitation_request_id,
|
||||
method: ElicitRequest::METHOD.to_string(),
|
||||
params: Some(serde_json::to_value(&ExecApprovalElicitRequestParams {
|
||||
message: expected_message,
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id,
|
||||
codex_event_id,
|
||||
codex_command: command,
|
||||
codex_cwd: workdir.to_path_buf(),
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
/// Test that patch approval triggers an elicitation request to the MCP and that
|
||||
/// sending the approval applies the patch, as expected.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_patch_approval_triggers_elicitation() {
|
||||
if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = patch_approval_triggers_elicitation().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
let cwd = TempDir::new()?;
|
||||
let test_file = cwd.path().join("destination_file.txt");
|
||||
std::fs::write(&test_file, "original content\n")?;
|
||||
|
||||
let patch_content = format!(
|
||||
"*** Begin Patch\n*** Update File: {}\n-original content\n+modified content\n*** End Patch",
|
||||
test_file.as_path().to_string_lossy()
|
||||
);
|
||||
|
||||
let McpHandle {
|
||||
process: mut mcp_process,
|
||||
server: _server,
|
||||
dir: _dir,
|
||||
} = create_mcp_process(vec![
|
||||
create_apply_patch_sse_response(&patch_content, "call1234")?,
|
||||
create_final_assistant_message_sse_response("Patch has been applied successfully!")?,
|
||||
])
|
||||
.await?;
|
||||
|
||||
// Send a "codex" tool request that will trigger the apply_patch command
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
cwd: Some(cwd.path().to_string_lossy().to_string()),
|
||||
prompt: "please modify the test file".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let elicitation_request = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_request_message(),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
|
||||
let mut expected_changes = HashMap::new();
|
||||
expected_changes.insert(
|
||||
test_file.as_path().to_path_buf(),
|
||||
FileChange::Update {
|
||||
unified_diff: "@@ -1 +1 @@\n-original content\n+modified content\n".to_string(),
|
||||
move_path: None,
|
||||
},
|
||||
);
|
||||
|
||||
let expected_elicitation_request = create_expected_patch_approval_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
expected_changes,
|
||||
None, // No grant_root expected
|
||||
None, // No reason expected
|
||||
codex_request_id.to_string(),
|
||||
"1".to_string(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
// Accept the patch approval request by responding to the elicitation
|
||||
mcp_process
|
||||
.send_response(
|
||||
elicitation_request_id,
|
||||
serde_json::to_value(PatchApprovalResponse {
|
||||
decision: ReviewDecision::Approved,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Verify the original `codex` tool call completes
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Patch has been applied successfully!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
let file_contents = std::fs::read_to_string(test_file.as_path())?;
|
||||
assert_eq!(file_contents, "modified content\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_codex_tool_passes_base_instructions() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Apparently `#[tokio::test]` must return `()`, so we create a helper
|
||||
// function that returns `Result` so we can use `?` in favor of `unwrap`.
|
||||
if let Err(err) = codex_tool_passes_base_instructions().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn codex_tool_passes_base_instructions() -> anyhow::Result<()> {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
let server =
|
||||
create_mock_chat_completions_server(vec![create_final_assistant_message_sse_response(
|
||||
"Enjoy!",
|
||||
)?])
|
||||
.await;
|
||||
|
||||
// Run `codex mcp` with a specific config.toml.
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
|
||||
// Send a "codex" tool request, which should hit the completions endpoint.
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(CodexToolCallParam {
|
||||
prompt: "How are you?".to_string(),
|
||||
base_instructions: Some("You are a helpful assistant.".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Enjoy!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
let request = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let instructions = request["messages"][0]["content"].as_str().unwrap();
|
||||
assert!(instructions.starts_with("You are a helpful assistant."));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_expected_patch_approval_elicitation_request(
|
||||
elicitation_request_id: RequestId,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
grant_root: Option<PathBuf>,
|
||||
reason: Option<String>,
|
||||
codex_mcp_tool_call_id: String,
|
||||
codex_event_id: String,
|
||||
) -> anyhow::Result<JSONRPCRequest> {
|
||||
let mut message_lines = Vec::new();
|
||||
if let Some(r) = &reason {
|
||||
message_lines.push(r.clone());
|
||||
}
|
||||
message_lines.push("Allow Codex to apply proposed code changes?".to_string());
|
||||
|
||||
Ok(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: elicitation_request_id,
|
||||
method: ElicitRequest::METHOD.to_string(),
|
||||
params: Some(serde_json::to_value(&PatchApprovalElicitRequestParams {
|
||||
message: message_lines.join("\n"),
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id,
|
||||
codex_event_id,
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
/// This handle is used to ensure that the MockServer and TempDir are not dropped while
|
||||
/// the McpProcess is still running.
|
||||
pub struct McpHandle {
|
||||
pub process: McpProcess,
|
||||
/// Retain the server for the lifetime of the McpProcess.
|
||||
#[allow(dead_code)]
|
||||
server: MockServer,
|
||||
/// Retain the temporary directory for the lifetime of the McpProcess.
|
||||
#[allow(dead_code)]
|
||||
dir: TempDir,
|
||||
}
|
||||
|
||||
async fn create_mcp_process(responses: Vec<String>) -> anyhow::Result<McpHandle> {
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
Ok(McpHandle {
|
||||
process: mcp_process,
|
||||
server,
|
||||
dir: codex_home,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Codex config that uses the mock server as the model provider.
/// It also uses `approval_policy = "untrusted"` so that we exercise the
/// elicitation code path for shell commands.
///
/// The file is written to `config.toml` inside `codex_home`, with the
/// provider's `base_url` set to `{server_uri}/v1`. Returns any I/O error from
/// writing the file.
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
    let config_toml = codex_home.join("config.toml");
    // The literal must contain only TOML: stray separator lines would end up
    // in the written file and make it unparseable.
    std::fs::write(
        config_toml,
        format!(
            r#"
model = "mock-model"
approval_policy = "untrusted"
sandbox_policy = "read-only"

model_provider = "mock_provider"

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
        ),
    )
}
|
||||
@@ -1,250 +0,0 @@
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::ChildStdin;
|
||||
use tokio::process::ChildStdout;
|
||||
|
||||
use anyhow::Context;
|
||||
use assert_cmd::prelude::*;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::InitializeRequestParams;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
use mcp_types::JSONRPCMessage;
|
||||
use mcp_types::JSONRPCNotification;
|
||||
use mcp_types::JSONRPCRequest;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::ModelContextProtocolNotification;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::process::Command as StdCommand;
|
||||
use tokio::process::Command;
|
||||
|
||||
pub struct McpProcess {
|
||||
next_request_id: AtomicI64,
|
||||
/// Retain this child process until the client is dropped. The Tokio runtime
|
||||
/// will make a "best effort" to reap the process after it exits, but it is
|
||||
/// not a guarantee. See the `kill_on_drop` documentation for details.
|
||||
#[allow(dead_code)]
|
||||
process: Child,
|
||||
stdin: ChildStdin,
|
||||
stdout: BufReader<ChildStdout>,
|
||||
}
|
||||
|
||||
impl McpProcess {
|
||||
pub async fn new(codex_home: &Path) -> anyhow::Result<Self> {
|
||||
// Use assert_cmd to locate the binary path and then switch to tokio::process::Command
|
||||
let std_cmd = StdCommand::cargo_bin("codex-mcp-server")
|
||||
.context("should find binary for codex-mcp-server")?;
|
||||
|
||||
let program = std_cmd.get_program().to_owned();
|
||||
|
||||
let mut cmd = Command::new(program);
|
||||
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
cmd.env("RUST_LOG", "debug");
|
||||
|
||||
let mut process = cmd
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.context("codex-mcp-server proc should start")?;
|
||||
let stdin = process
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::format_err!("mcp should have stdin fd"))?;
|
||||
let stdout = process
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::format_err!("mcp should have stdout fd"))?;
|
||||
let stdout = BufReader::new(stdout);
|
||||
Ok(Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
process,
|
||||
stdin,
|
||||
stdout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Performs the initialization handshake with the MCP server.
|
||||
pub async fn initialize(&mut self) -> anyhow::Result<()> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let params = InitializeRequestParams {
|
||||
capabilities: ClientCapabilities {
|
||||
elicitation: Some(json!({})),
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "elicitation test".into(),
|
||||
title: Some("Elicitation Test".into()),
|
||||
version: "0.0.0".into(),
|
||||
},
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(),
|
||||
};
|
||||
let params_value = serde_json::to_value(params)?;
|
||||
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(request_id),
|
||||
method: mcp_types::InitializeRequest::METHOD.into(),
|
||||
params: Some(params_value),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let initialized = self.read_jsonrpc_message().await?;
|
||||
assert_eq!(
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(request_id),
|
||||
result: json!({
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": true
|
||||
},
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "codex-mcp-server",
|
||||
"title": "Codex",
|
||||
"version": "0.0.0"
|
||||
},
|
||||
"protocolVersion": mcp_types::MCP_SCHEMA_VERSION
|
||||
})
|
||||
}),
|
||||
initialized
|
||||
);
|
||||
|
||||
// Send notifications/initialized to ack the response.
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
method: mcp_types::InitializedNotification::METHOD.into(),
|
||||
params: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the id used to make the request so it can be used when
|
||||
/// correlating notifications.
|
||||
pub async fn send_codex_tool_call(
|
||||
&mut self,
|
||||
params: CodexToolCallParam,
|
||||
) -> anyhow::Result<i64> {
|
||||
let codex_tool_call_params = CallToolRequestParams {
|
||||
name: "codex".to_string(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
};
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(codex_tool_call_params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<i64> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(request_id),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
});
|
||||
self.send_jsonrpc_message(message).await?;
|
||||
Ok(request_id)
|
||||
}
|
||||
|
||||
pub async fn send_response(
|
||||
&mut self,
|
||||
id: RequestId,
|
||||
result: serde_json::Value,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_jsonrpc_message(JSONRPCMessage::Response(JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id,
|
||||
result,
|
||||
}))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_jsonrpc_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
|
||||
let payload = serde_json::to_string(&message)?;
|
||||
self.stdin.write_all(payload.as_bytes()).await?;
|
||||
self.stdin.write_all(b"\n").await?;
|
||||
self.stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_jsonrpc_message(&mut self) -> anyhow::Result<JSONRPCMessage> {
|
||||
let mut line = String::new();
|
||||
self.stdout.read_line(&mut line).await?;
|
||||
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(_) => {
|
||||
eprintln!("notification: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Request(jsonrpc_request) => {
|
||||
return Ok(jsonrpc_request);
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_stream_until_response_message(
|
||||
&mut self,
|
||||
request_id: RequestId,
|
||||
) -> anyhow::Result<JSONRPCResponse> {
|
||||
loop {
|
||||
let message = self.read_jsonrpc_message().await?;
|
||||
eprint!("message: {message:?}");
|
||||
|
||||
match message {
|
||||
JSONRPCMessage::Notification(_) => {
|
||||
eprintln!("notification: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Request(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Error(_) => {
|
||||
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||
}
|
||||
JSONRPCMessage::Response(jsonrpc_response) => {
|
||||
if jsonrpc_response.id == request_id {
|
||||
return Ok(jsonrpc_response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
/// Create a mock server that will provide the responses, in order, for
|
||||
/// requests to the `/v1/chat/completions` endpoint.
|
||||
pub async fn create_mock_chat_completions_server(responses: Vec<String>) -> MockServer {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let num_calls = responses.len();
|
||||
let seq_responder = SeqResponder {
|
||||
num_calls: AtomicUsize::new(0),
|
||||
responses,
|
||||
};
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/chat/completions"))
|
||||
.respond_with(seq_responder)
|
||||
.expect(num_calls as u64)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
struct SeqResponder {
|
||||
num_calls: AtomicUsize,
|
||||
responses: Vec<String>,
|
||||
}
|
||||
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
|
||||
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
|
||||
match self.responses.get(call_num) {
|
||||
Some(response) => ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(response.clone(), "text/event-stream"),
|
||||
None => panic!("no response for {call_num}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
mod mcp_process;
|
||||
mod mock_model_server;
|
||||
mod responses;
|
||||
|
||||
pub use mcp_process::McpProcess;
|
||||
pub use mock_model_server::create_mock_chat_completions_server;
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
pub use responses::create_final_assistant_message_sse_response;
|
||||
pub use responses::create_shell_sse_response;
|
||||
@@ -1,95 +0,0 @@
|
||||
use serde_json::json;
|
||||
use std::path::Path;
|
||||
|
||||
pub fn create_shell_sse_response(
|
||||
command: Vec<String>,
|
||||
workdir: Option<&Path>,
|
||||
timeout_ms: Option<u64>,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// The `arguments`` for the `shell` tool is a serialized JSON object.
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"command": command,
|
||||
"workdir": workdir.map(|w| w.to_string_lossy()),
|
||||
"timeout": timeout_ms
|
||||
}))?;
|
||||
let tool_call = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": tool_call_arguments
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&tool_call)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
|
||||
let assistant_message = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": message
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&assistant_message)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
pub fn create_apply_patch_sse_response(
|
||||
patch_content: &str,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// Use shell command to call apply_patch with heredoc format
|
||||
let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF");
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"command": ["bash", "-lc", shell_command]
|
||||
}))?;
|
||||
|
||||
let tool_call = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": tool_call_arguments
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&tool_call)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Types for Model Context Protocol. Inspired by https://crates.io/crates/lsp-types.
|
||||
|
||||
As documented on https://modelcontextprotocol.io/specification/2025-06-18/basic:
|
||||
As documented on https://modelcontextprotocol.io/specification/2025-03-26/basic:
|
||||
|
||||
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.ts
|
||||
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.json
|
||||
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts
|
||||
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa: E501
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -14,13 +13,10 @@ from pathlib import Path
|
||||
# Helper first so it is defined when other functions call it.
|
||||
from typing import Any, Literal
|
||||
|
||||
SCHEMA_VERSION = "2025-06-18"
|
||||
SCHEMA_VERSION = "2025-03-26"
|
||||
JSONRPC_VERSION = "2.0"
|
||||
|
||||
STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n"
|
||||
STANDARD_HASHABLE_DERIVE = (
|
||||
"#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]\n"
|
||||
)
|
||||
|
||||
# Will be populated with the schema's `definitions` map in `main()` so that
|
||||
# helper functions (for example `define_any_of`) can perform look-ups while
|
||||
@@ -30,27 +26,19 @@ DEFINITIONS: dict[str, Any] = {}
|
||||
CLIENT_REQUEST_TYPE_NAMES: list[str] = []
|
||||
# Concrete *Notification types that make up the ServerNotification enum.
|
||||
SERVER_NOTIFICATION_TYPE_NAMES: list[str] = []
|
||||
# Enum types that will need a `allow(clippy::large_enum_variant)` annotation in
|
||||
# order to compile without warnings.
|
||||
LARGE_ENUMS = {"ServerResult"}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Embed, cluster and analyse text prompts via the OpenAI API.",
|
||||
)
|
||||
|
||||
default_schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"schema_file",
|
||||
nargs="?",
|
||||
default=default_schema_file,
|
||||
help="schema.json file to process",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
schema_file = args.schema_file
|
||||
num_args = len(sys.argv)
|
||||
if num_args == 1:
|
||||
schema_file = (
|
||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||
)
|
||||
elif num_args == 2:
|
||||
schema_file = Path(sys.argv[1])
|
||||
else:
|
||||
print("Usage: python3 codegen.py <schema.json>")
|
||||
return 1
|
||||
|
||||
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
||||
|
||||
@@ -209,8 +197,6 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
|
||||
if name.endswith("Result"):
|
||||
out.extend(f"impl From<{name}> for serde_json::Value {{\n")
|
||||
out.append(f" fn from(value: {name}) -> Self {{\n")
|
||||
out.append(" // Leave this as it should never fail\n")
|
||||
out.append(" #[expect(clippy::unwrap_used)]\n")
|
||||
out.append(" serde_json::to_value(value).unwrap()\n")
|
||||
out.append(" }\n")
|
||||
out.append("}\n\n")
|
||||
@@ -225,7 +211,20 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non
|
||||
any_of = definition.get("anyOf", [])
|
||||
if any_of:
|
||||
assert isinstance(any_of, list)
|
||||
out.extend(define_any_of(name, any_of, description))
|
||||
if name == "JSONRPCMessage":
|
||||
# Special case for JSONRPCMessage because its definition in the
|
||||
# JSON schema does not quite match how we think about this type
|
||||
# definition in Rust.
|
||||
deep_copied_any_of = json.loads(json.dumps(any_of))
|
||||
deep_copied_any_of[2] = {
|
||||
"$ref": "#/definitions/JSONRPCBatchRequest",
|
||||
}
|
||||
deep_copied_any_of[5] = {
|
||||
"$ref": "#/definitions/JSONRPCBatchResponse",
|
||||
}
|
||||
out.extend(define_any_of(name, deep_copied_any_of, description))
|
||||
else:
|
||||
out.extend(define_any_of(name, any_of, description))
|
||||
return
|
||||
|
||||
type_prop = definition.get("type", None)
|
||||
@@ -394,7 +393,7 @@ def define_string_enum(
|
||||
|
||||
|
||||
def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None:
|
||||
out.append(STANDARD_HASHABLE_DERIVE)
|
||||
out.append(STANDARD_DERIVE)
|
||||
out.append("#[serde(untagged)]\n")
|
||||
out.append(f"pub enum {name} {{\n")
|
||||
for simple_type in type_list:
|
||||
@@ -440,8 +439,6 @@ def define_any_of(
|
||||
if serde := get_serde_annotation_for_anyof_type(name):
|
||||
out.append(serde + "\n")
|
||||
|
||||
if name in LARGE_ENUMS:
|
||||
out.append("#[allow(clippy::large_enum_variant)]\n")
|
||||
out.append(f"pub enum {name} {{\n")
|
||||
|
||||
if name == "ClientRequest":
|
||||
@@ -599,8 +596,6 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
||||
prop_name = "r#type"
|
||||
elif name == "ref":
|
||||
prop_name = "r#ref"
|
||||
elif name == "enum":
|
||||
prop_name = "r#enum"
|
||||
elif snake_case := to_snake_case(name):
|
||||
prop_name = snake_case
|
||||
is_rename = True
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub const MCP_SCHEMA_VERSION: &str = "2025-06-18";
|
||||
pub const MCP_SCHEMA_VERSION: &str = "2025-03-26";
|
||||
pub const JSONRPC_VERSION: &str = "2.0";
|
||||
|
||||
/// Paired request/response types for the Model Context Protocol (MCP).
|
||||
@@ -35,12 +35,6 @@ fn default_jsonrpc() -> String {
|
||||
pub struct Annotations {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub audience: Option<Vec<Role>>,
|
||||
#[serde(
|
||||
rename = "lastModified",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub last_modified: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<f64>,
|
||||
}
|
||||
@@ -56,14 +50,6 @@ pub struct AudioContent {
|
||||
pub r#type: String, // &'static str = "audio"
|
||||
}
|
||||
|
||||
/// Base interface for metadata with name (identifier) and title (display name) properties.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BaseMetadata {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BlobResourceContents {
|
||||
pub blob: String,
|
||||
@@ -72,17 +58,6 @@ pub struct BlobResourceContents {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct BooleanSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "boolean"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum CallToolRequest {}
|
||||
|
||||
@@ -100,17 +75,29 @@ pub struct CallToolRequestParams {
|
||||
}
|
||||
|
||||
/// The server's response to a tool call.
|
||||
///
|
||||
/// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
/// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
/// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
/// and self-correct.
|
||||
///
|
||||
/// However, any errors in _finding_ the tool, an error indicating that the
|
||||
/// server does not support tool calls, or any other exceptional conditions,
|
||||
/// should be reported as an MCP error response.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CallToolResult {
|
||||
pub content: Vec<ContentBlock>,
|
||||
pub content: Vec<CallToolResultContent>,
|
||||
#[serde(rename = "isError", default, skip_serializing_if = "Option::is_none")]
|
||||
pub is_error: Option<bool>,
|
||||
#[serde(
|
||||
rename = "structuredContent",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub structured_content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CallToolResultContent {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
impl From<CallToolResult> for serde_json::Value {
|
||||
@@ -140,8 +127,6 @@ pub struct CancelledNotificationParams {
|
||||
/// Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ClientCapabilities {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub elicitation: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub experimental: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
@@ -209,7 +194,6 @@ pub enum ClientResult {
|
||||
Result(Result),
|
||||
CreateMessageResult(CreateMessageResult),
|
||||
ListRootsResult(ListRootsResult),
|
||||
ElicitResult(ElicitResult),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -224,18 +208,9 @@ impl ModelContextProtocolRequest for CompleteRequest {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParams {
|
||||
pub argument: CompleteRequestParamsArgument,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<CompleteRequestParamsContext>,
|
||||
pub r#ref: CompleteRequestParamsRef,
|
||||
}
|
||||
|
||||
/// Additional, optional context for completions
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParamsContext {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// The argument's information
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct CompleteRequestParamsArgument {
|
||||
@@ -247,7 +222,7 @@ pub struct CompleteRequestParamsArgument {
|
||||
#[serde(untagged)]
|
||||
pub enum CompleteRequestParamsRef {
|
||||
PromptReference(PromptReference),
|
||||
ResourceTemplateReference(ResourceTemplateReference),
|
||||
ResourceReference(ResourceReference),
|
||||
}
|
||||
|
||||
/// The server's response to a completion/complete request
|
||||
@@ -273,16 +248,6 @@ impl From<CompleteResult> for serde_json::Value {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentBlock {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
ResourceLink(ResourceLink),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum CreateMessageRequest {}
|
||||
|
||||
@@ -360,48 +325,6 @@ impl From<CreateMessageResult> for serde_json::Value {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct Cursor(String);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ElicitRequest {}
|
||||
|
||||
impl ModelContextProtocolRequest for ElicitRequest {
|
||||
const METHOD: &'static str = "elicitation/create";
|
||||
type Params = ElicitRequestParams;
|
||||
type Result = ElicitResult;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitRequestParams {
|
||||
pub message: String,
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
}
|
||||
|
||||
/// A restricted subset of JSON Schema.
|
||||
/// Only top-level properties are allowed, without nesting.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitRequestParamsRequestedSchema {
|
||||
pub properties: serde_json::Value,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<Vec<String>>,
|
||||
pub r#type: String, // &'static str = "object"
|
||||
}
|
||||
|
||||
/// The client's response to an elicitation request.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ElicitResult {
|
||||
pub action: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl From<ElicitResult> for serde_json::Value {
|
||||
fn from(value: ElicitResult) -> Self {
|
||||
// Leave this as it should never fail
|
||||
#[expect(clippy::unwrap_used)]
|
||||
serde_json::to_value(value).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// The contents of a resource, embedded into a prompt or tool call result.
|
||||
///
|
||||
/// It is up to the client how best to render embedded resources for the benefit
|
||||
@@ -423,18 +346,6 @@ pub enum EmbeddedResourceResource {
|
||||
|
||||
pub type EmptyResult = Result;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct EnumSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub r#enum: Vec<String>,
|
||||
#[serde(rename = "enumNames", default, skip_serializing_if = "Option::is_none")]
|
||||
pub enum_names: Option<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "string"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum GetPromptRequest {}
|
||||
|
||||
@@ -478,12 +389,10 @@ pub struct ImageContent {
|
||||
pub r#type: String, // &'static str = "image"
|
||||
}
|
||||
|
||||
/// Describes the name and version of an MCP implementation, with an optional title for UI representation.
|
||||
/// Describes the name and version of an MCP implementation.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct Implementation {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
@@ -533,6 +442,24 @@ impl ModelContextProtocolNotification for InitializedNotification {
|
||||
type Params = Option<serde_json::Value>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JSONRPCBatchRequestItem {
|
||||
JSONRPCRequest(JSONRPCRequest),
|
||||
JSONRPCNotification(JSONRPCNotification),
|
||||
}
|
||||
|
||||
pub type JSONRPCBatchRequest = Vec<JSONRPCBatchRequestItem>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JSONRPCBatchResponseItem {
|
||||
JSONRPCResponse(JSONRPCResponse),
|
||||
JSONRPCError(JSONRPCError),
|
||||
}
|
||||
|
||||
pub type JSONRPCBatchResponse = Vec<JSONRPCBatchResponseItem>;
|
||||
|
||||
/// A response to a request that indicates an error occurred.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct JSONRPCError {
|
||||
@@ -556,8 +483,10 @@ pub struct JSONRPCErrorError {
|
||||
pub enum JSONRPCMessage {
|
||||
Request(JSONRPCRequest),
|
||||
Notification(JSONRPCNotification),
|
||||
BatchRequest(JSONRPCBatchRequest),
|
||||
Response(JSONRPCResponse),
|
||||
Error(JSONRPCError),
|
||||
BatchResponse(JSONRPCBatchResponse),
|
||||
}
|
||||
|
||||
/// A notification which does not expect a response.
|
||||
@@ -848,19 +777,6 @@ pub struct Notification {
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct NumberSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub maximum: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub minimum: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PaginatedRequest {
|
||||
pub method: String,
|
||||
@@ -901,17 +817,6 @@ impl ModelContextProtocolRequest for PingRequest {
|
||||
type Result = Result;
|
||||
}
|
||||
|
||||
/// Restricted schema definitions that only allow primitive types
|
||||
/// without nested objects or arrays.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PrimitiveSchemaDefinition {
|
||||
StringSchema(StringSchema),
|
||||
NumberSchema(NumberSchema),
|
||||
BooleanSchema(BooleanSchema),
|
||||
EnumSchema(EnumSchema),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ProgressNotification {}
|
||||
|
||||
@@ -931,7 +836,7 @@ pub struct ProgressNotificationParams {
|
||||
pub total: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProgressToken {
|
||||
String(String),
|
||||
@@ -946,8 +851,6 @@ pub struct Prompt {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// Describes an argument that a prompt can accept.
|
||||
@@ -958,8 +861,6 @@ pub struct PromptArgument {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -976,16 +877,23 @@ impl ModelContextProtocolNotification for PromptListChangedNotification {
|
||||
/// resources from the MCP server.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PromptMessage {
|
||||
pub content: ContentBlock,
|
||||
pub content: PromptMessageContent,
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PromptMessageContent {
|
||||
TextContent(TextContent),
|
||||
ImageContent(ImageContent),
|
||||
AudioContent(AudioContent),
|
||||
EmbeddedResource(EmbeddedResource),
|
||||
}
|
||||
|
||||
/// Identifies a prompt.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct PromptReference {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "ref/prompt"
|
||||
}
|
||||
|
||||
@@ -1031,7 +939,7 @@ pub struct Request {
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RequestId {
|
||||
String(String),
|
||||
@@ -1050,8 +958,6 @@ pub struct Resource {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
@@ -1063,26 +969,6 @@ pub struct ResourceContents {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// A resource that the server is capable of reading, included in a prompt or tool call result.
|
||||
///
|
||||
/// Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceLink {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub annotations: Option<Annotations>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "resource_link"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ResourceListChangedNotification {}
|
||||
|
||||
@@ -1091,6 +977,13 @@ impl ModelContextProtocolNotification for ResourceListChangedNotification {
|
||||
type Params = Option<serde_json::Value>;
|
||||
}
|
||||
|
||||
/// A reference to a resource or resource template definition.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: String, // &'static str = "ref/resource"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
/// A template description for resources available on the server.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceTemplate {
|
||||
@@ -1101,19 +994,10 @@ pub struct ResourceTemplate {
|
||||
#[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
#[serde(rename = "uriTemplate")]
|
||||
pub uri_template: String,
|
||||
}
|
||||
|
||||
/// A reference to a resource or resource template definition.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ResourceTemplateReference {
|
||||
pub r#type: String, // &'static str = "ref/resource"
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ResourceUpdatedNotification {}
|
||||
|
||||
@@ -1256,7 +1140,6 @@ pub enum ServerRequest {
|
||||
PingRequest(PingRequest),
|
||||
CreateMessageRequest(CreateMessageRequest),
|
||||
ListRootsRequest(ListRootsRequest),
|
||||
ElicitRequest(ElicitRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
@@ -1289,21 +1172,6 @@ pub struct SetLevelRequestParams {
|
||||
pub level: LoggingLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct StringSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
#[serde(rename = "maxLength", default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_length: Option<i64>,
|
||||
#[serde(rename = "minLength", default, skip_serializing_if = "Option::is_none")]
|
||||
pub min_length: Option<i64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub r#type: String, // &'static str = "string"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub enum SubscribeRequest {}
|
||||
|
||||
@@ -1345,25 +1213,6 @@ pub struct Tool {
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: ToolInputSchema,
|
||||
pub name: String,
|
||||
#[serde(
|
||||
rename = "outputSchema",
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub output_schema: Option<ToolOutputSchema>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// An optional JSON Schema object defining the structure of the tool's output returned in
|
||||
/// the structuredContent field of a CallToolResult.
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ToolOutputSchema {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub properties: Option<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<Vec<String>>,
|
||||
pub r#type: String, // &'static str = "object"
|
||||
}
|
||||
|
||||
/// A JSON Schema object defining the expected parameters for the tool.
|
||||
|
||||
@@ -17,8 +17,8 @@ fn deserialize_initialize_request() {
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"capabilities": {},
|
||||
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-06-18"
|
||||
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-03-26"
|
||||
}
|
||||
}"#;
|
||||
|
||||
@@ -37,8 +37,8 @@ fn deserialize_initialize_request() {
|
||||
method: "initialize".into(),
|
||||
params: Some(json!({
|
||||
"capabilities": {},
|
||||
"clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-06-18"
|
||||
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
|
||||
"protocolVersion": "2025-03-26"
|
||||
})),
|
||||
};
|
||||
|
||||
@@ -57,14 +57,12 @@ fn deserialize_initialize_request() {
|
||||
experimental: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "acme-client".into(),
|
||||
title: Some("Acme".to_string()),
|
||||
version: "1.2.3".into(),
|
||||
},
|
||||
protocol_version: "2025-06-18".into(),
|
||||
protocol_version: "2025-03-26".into(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
[toolchain]
|
||||
channel = "1.88.0"
|
||||
components = [ "clippy", "rustfmt", "rust-src"]
|
||||
@@ -18,16 +18,8 @@ use crossterm::event::KeyEvent;
|
||||
use crossterm::event::MouseEvent;
|
||||
use crossterm::event::MouseEventKind;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::Receiver;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Time window for debouncing redraw requests.
|
||||
const REDRAW_DEBOUNCE: Duration = Duration::from_millis(10);
|
||||
|
||||
/// Top-level application state: which full-screen view is currently active.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
@@ -54,9 +46,6 @@ pub(crate) struct App<'a> {
|
||||
|
||||
file_search: FileSearchManager,
|
||||
|
||||
/// True when a redraw has been scheduled but not yet executed.
|
||||
pending_redraw: Arc<AtomicBool>,
|
||||
|
||||
/// Stored parameters needed to instantiate the ChatWidget later, e.g.,
|
||||
/// after dismissing the Git-repo warning.
|
||||
chat_args: Option<ChatWidgetArgs>,
|
||||
@@ -71,7 +60,7 @@ struct ChatWidgetArgs {
|
||||
initial_images: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl App<'_> {
|
||||
impl<'a> App<'a> {
|
||||
pub(crate) fn new(
|
||||
config: Config,
|
||||
initial_prompt: Option<String>,
|
||||
@@ -81,7 +70,6 @@ impl App<'_> {
|
||||
) -> Self {
|
||||
let (app_event_tx, app_event_rx) = channel();
|
||||
let app_event_tx = AppEventSender::new(app_event_tx);
|
||||
let pending_redraw = Arc::new(AtomicBool::new(false));
|
||||
let scroll_event_helper = ScrollEventHelper::new(app_event_tx.clone());
|
||||
|
||||
// Spawn a dedicated thread for reading the crossterm event loop and
|
||||
@@ -95,7 +83,7 @@ impl App<'_> {
|
||||
app_event_tx.send(AppEvent::KeyEvent(key_event));
|
||||
}
|
||||
crossterm::event::Event::Resize(_, _) => {
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
crossterm::event::Event::Mouse(MouseEvent {
|
||||
kind: MouseEventKind::ScrollUp,
|
||||
@@ -164,7 +152,6 @@ impl App<'_> {
|
||||
app_state,
|
||||
config,
|
||||
file_search,
|
||||
pending_redraw,
|
||||
chat_args,
|
||||
}
|
||||
}
|
||||
@@ -175,28 +162,6 @@ impl App<'_> {
|
||||
self.app_event_tx.clone()
|
||||
}
|
||||
|
||||
/// Schedule a redraw if one is not already pending.
|
||||
#[allow(clippy::unwrap_used)]
|
||||
fn schedule_redraw(&self) {
|
||||
// Attempt to set the flag to `true`. If it was already `true`, another
|
||||
// redraw is already pending so we can return early.
|
||||
if self
|
||||
.pending_redraw
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let tx = self.app_event_tx.clone();
|
||||
let pending_redraw = self.pending_redraw.clone();
|
||||
thread::spawn(move || {
|
||||
thread::sleep(REDRAW_DEBOUNCE);
|
||||
tx.send(AppEvent::Redraw);
|
||||
pending_redraw.store(false, Ordering::SeqCst);
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn run(
|
||||
&mut self,
|
||||
terminal: &mut tui::Tui,
|
||||
@@ -204,13 +169,10 @@ impl App<'_> {
|
||||
) -> Result<()> {
|
||||
// Insert an event to trigger the first render.
|
||||
let app_event_tx = self.app_event_tx.clone();
|
||||
app_event_tx.send(AppEvent::RequestRedraw);
|
||||
app_event_tx.send(AppEvent::Redraw);
|
||||
|
||||
while let Ok(event) = self.app_event_rx.recv() {
|
||||
match event {
|
||||
AppEvent::RequestRedraw => {
|
||||
self.schedule_redraw();
|
||||
}
|
||||
AppEvent::Redraw => {
|
||||
self.draw_next_frame(terminal)?;
|
||||
}
|
||||
@@ -287,7 +249,7 @@ impl App<'_> {
|
||||
Vec::new(),
|
||||
));
|
||||
self.app_state = AppState::Chat { widget: new_widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
self.app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
SlashCommand::ToggleMouseMode => {
|
||||
if let Err(e) = mouse_capture.toggle() {
|
||||
@@ -374,7 +336,7 @@ impl App<'_> {
|
||||
args.initial_images,
|
||||
));
|
||||
self.app_state = AppState::Chat { widget };
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
self.app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
GitWarningOutcome::Quit => {
|
||||
self.app_event_tx.send(AppEvent::ExitRequest);
|
||||
|
||||
@@ -8,10 +8,6 @@ use crate::slash_command::SlashCommand;
|
||||
pub(crate) enum AppEvent {
|
||||
CodexEvent(Event),
|
||||
|
||||
/// Request a redraw which will be debounced by the [`App`].
|
||||
RequestRedraw,
|
||||
|
||||
/// Actually draw the next frame.
|
||||
Redraw,
|
||||
|
||||
KeyEvent(KeyEvent),
|
||||
|
||||
@@ -212,7 +212,7 @@ impl BottomPane<'_> {
|
||||
}
|
||||
|
||||
pub(crate) fn request_redraw(&self) {
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw)
|
||||
self.app_event_tx.send(AppEvent::Redraw)
|
||||
}
|
||||
|
||||
/// Returns true when a popup inside the composer is visible.
|
||||
|
||||
@@ -24,7 +24,7 @@ impl StatusIndicatorView {
|
||||
}
|
||||
}
|
||||
|
||||
impl BottomPaneView<'_> for StatusIndicatorView {
|
||||
impl<'a> BottomPaneView<'a> for StatusIndicatorView {
|
||||
fn update_status_text(&mut self, text: String) -> ConditionalUpdate {
|
||||
self.update_text(text);
|
||||
ConditionalUpdate::NeedsRedraw
|
||||
|
||||
@@ -96,15 +96,14 @@ impl ChatWidget<'_> {
|
||||
// Create the Codex asynchronously so the UI loads as quickly as possible.
|
||||
let config_for_agent_loop = config.clone();
|
||||
tokio::spawn(async move {
|
||||
let (codex, session_event, _ctrl_c, _session_id) =
|
||||
match init_codex(config_for_agent_loop).await {
|
||||
Ok(vals) => vals,
|
||||
Err(e) => {
|
||||
// TODO: surface this error to the user.
|
||||
tracing::error!("failed to initialize codex: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let (codex, session_event, _ctrl_c) = match init_codex(config_for_agent_loop).await {
|
||||
Ok(vals) => vals,
|
||||
Err(e) => {
|
||||
// TODO: surface this error to the user.
|
||||
tracing::error!("failed to initialize codex: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Forward the captured `SessionInitialized` event that was consumed
|
||||
// inside `init_codex()` so it can be rendered in the UI.
|
||||
@@ -432,7 +431,7 @@ impl ChatWidget<'_> {
|
||||
}
|
||||
|
||||
fn request_redraw(&mut self) {
|
||||
self.app_event_tx.send(AppEvent::RequestRedraw);
|
||||
self.app_event_tx.send(AppEvent::Redraw);
|
||||
}
|
||||
|
||||
pub(crate) fn add_diff_output(&mut self, diff_output: String) {
|
||||
@@ -465,8 +464,6 @@ impl ChatWidget<'_> {
|
||||
if self.bottom_pane.is_task_running() {
|
||||
self.bottom_pane.clear_ctrl_c_quit_hint();
|
||||
self.submit_op(Op::Interrupt);
|
||||
self.answer_buffer.clear();
|
||||
self.reasoning_buffer.clear();
|
||||
false
|
||||
} else if self.bottom_pane.ctrl_c_quit_hint_visible() {
|
||||
true
|
||||
|
||||
@@ -3,7 +3,7 @@ use codex_common::ApprovalModeCliArg;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version)]
|
||||
pub struct Cli {
|
||||
/// Optional user prompt to start the session.
|
||||
|
||||
@@ -17,7 +17,6 @@ use image::GenericImageView;
|
||||
use image::ImageReader;
|
||||
use lazy_static::lazy_static;
|
||||
use mcp_types::EmbeddedResourceResource;
|
||||
use mcp_types::ResourceLink;
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::style::Color;
|
||||
use ratatui::style::Modifier;
|
||||
@@ -332,7 +331,8 @@ impl HistoryCell {
|
||||
) -> Option<Self> {
|
||||
match result {
|
||||
Ok(mcp_types::CallToolResult { content, .. }) => {
|
||||
if let Some(mcp_types::ContentBlock::ImageContent(image)) = content.first() {
|
||||
if let Some(mcp_types::CallToolResultContent::ImageContent(image)) = content.first()
|
||||
{
|
||||
let raw_data =
|
||||
match base64::engine::general_purpose::STANDARD.decode(&image.data) {
|
||||
Ok(data) => data,
|
||||
@@ -405,21 +405,21 @@ impl HistoryCell {
|
||||
|
||||
for tool_call_result in content {
|
||||
let line_text = match tool_call_result {
|
||||
mcp_types::ContentBlock::TextContent(text) => {
|
||||
mcp_types::CallToolResultContent::TextContent(text) => {
|
||||
format_and_truncate_tool_result(
|
||||
&text.text,
|
||||
TOOL_CALL_MAX_LINES,
|
||||
num_cols as usize,
|
||||
)
|
||||
}
|
||||
mcp_types::ContentBlock::ImageContent(_) => {
|
||||
mcp_types::CallToolResultContent::ImageContent(_) => {
|
||||
// TODO show images even if they're not the first result, will require a refactor of `CompletedMcpToolCall`
|
||||
"<image content>".to_string()
|
||||
}
|
||||
mcp_types::ContentBlock::AudioContent(_) => {
|
||||
mcp_types::CallToolResultContent::AudioContent(_) => {
|
||||
"<audio content>".to_string()
|
||||
}
|
||||
mcp_types::ContentBlock::EmbeddedResource(resource) => {
|
||||
mcp_types::CallToolResultContent::EmbeddedResource(resource) => {
|
||||
let uri = match resource.resource {
|
||||
EmbeddedResourceResource::TextResourceContents(text) => {
|
||||
text.uri
|
||||
@@ -430,9 +430,6 @@ impl HistoryCell {
|
||||
};
|
||||
format!("embedded resource: {uri}")
|
||||
}
|
||||
mcp_types::ContentBlock::ResourceLink(ResourceLink { uri, .. }) => {
|
||||
format!("link: {uri}")
|
||||
}
|
||||
};
|
||||
lines.push(Line::styled(line_text, Style::default().fg(Color::Gray)));
|
||||
}
|
||||
|
||||
@@ -75,7 +75,6 @@ pub fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> std::io::
|
||||
model_provider: None,
|
||||
config_profile: cli.config_profile.clone(),
|
||||
codex_linux_sandbox_exe,
|
||||
base_instructions: None,
|
||||
};
|
||||
// Parse `-c` overrides from the CLI.
|
||||
let cli_kv_overrides = match cli.config_overrides.parse_overrides() {
|
||||
|
||||
@@ -65,7 +65,7 @@ impl StatusIndicatorWidget {
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
counter = counter.wrapping_add(1);
|
||||
frame_idx_clone.store(counter, Ordering::Relaxed);
|
||||
app_event_tx_clone.send(AppEvent::RequestRedraw);
|
||||
app_event_tx_clone.send(AppEvent::Redraw);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user