mirror of
https://github.com/openai/codex.git
synced 2026-02-07 17:33:41 +00:00
Compare commits
13 Commits
scrollable
...
codex-conc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6b2c5c772 | ||
|
|
aed712286b | ||
|
|
6fcedb46a9 | ||
|
|
81bb1c9e26 | ||
|
|
7e0f506da2 | ||
|
|
929ba50adc | ||
|
|
80555d4ff2 | ||
|
|
97ab8fb610 | ||
|
|
fe62f859a6 | ||
|
|
92f3566d78 | ||
|
|
f20de21cb6 | ||
|
|
bc7beddaa2 | ||
|
|
8360c6a3ec |
@@ -2,7 +2,9 @@
|
||||
|
||||
In the codex-rs folder where the rust code lives:
|
||||
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR`. You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
- Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` or `CODEX_SANDBOX_ENV_VAR`.
|
||||
- You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations.
|
||||
- Similarly, when you spawn a process using Seatbelt (`/usr/bin/sandbox-exec`), `CODEX_SANDBOX=seatbelt` will be set on the child process. Integration tests that want to run Seatbelt themselves cannot be run under Seatbelt, so checks for `CODEX_SANDBOX=seatbelt` are also often used to early exit out of tests, as appropriate.
|
||||
|
||||
Before creating a pull request with changes to `codex-rs`, run `just fmt` (in `codex-rs` directory) to format the code and `just fix` (in `codex-rs` directory) to fix any linter issues in the code, ensure the test suite passes by running `cargo test --all-features` in the `codex-rs` directory.
|
||||
|
||||
|
||||
@@ -83,6 +83,7 @@ if (wantsNative && process.platform !== 'win32') {
|
||||
|
||||
const child = spawn(binaryPath, process.argv.slice(2), {
|
||||
stdio: "inherit",
|
||||
env: { ...process.env, CODEX_MANAGED_BY_NPM: "1" },
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
|
||||
13
codex-rs/Cargo.lock
generated
13
codex-rs/Cargo.lock
generated
@@ -695,6 +695,7 @@ dependencies = [
|
||||
"reqwest",
|
||||
"seccompiler",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"shlex",
|
||||
@@ -842,6 +843,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-ansi-escape",
|
||||
"codex-arg0",
|
||||
@@ -860,6 +862,8 @@ dependencies = [
|
||||
"ratatui",
|
||||
"ratatui-image",
|
||||
"regex-lite",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"strum 0.27.2",
|
||||
@@ -3951,6 +3955,15 @@ dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_bytes"
|
||||
version = "0.11.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8437fd221bde2d4ca316d61b90e337e9e702b3820b87d63caa9ba6c02bd06d96"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.219"
|
||||
|
||||
584
codex-rs/cli/src/concurrent/mod.rs
Normal file
584
codex-rs/cli/src/concurrent/mod.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::process::Stdio;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use tokio::process::Command as TokioCommand;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use anyhow::Context;
|
||||
use codex_common::CliConfigOverrides;
|
||||
use codex_exec::Cli as ExecCli;
|
||||
|
||||
// Serialize git worktree add operations across tasks to avoid repository lock contention.
|
||||
static GIT_WORKTREE_ADD_SEMAPHORE: OnceLock<Semaphore> = OnceLock::new();
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConcurrentRunResult {
|
||||
pub branch: String,
|
||||
pub worktree_dir: PathBuf,
|
||||
pub log_file: Option<PathBuf>,
|
||||
pub exec_exit_code: Option<i32>,
|
||||
pub _had_changes: bool,
|
||||
pub _applied_changes: Option<usize>,
|
||||
}
|
||||
|
||||
fn compute_codex_home() -> PathBuf {
|
||||
if let Ok(val) = std::env::var("CODEX_HOME") {
|
||||
if !val.is_empty() {
|
||||
return PathBuf::from(val);
|
||||
}
|
||||
}
|
||||
// Fallback to default (~/.codex) without requiring it to already exist.
|
||||
codex_core::config::find_codex_home().unwrap_or_else(|_| {
|
||||
let mut p = std::env::var_os("HOME")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_default();
|
||||
if p.as_os_str().is_empty() {
|
||||
return PathBuf::from(".codex");
|
||||
}
|
||||
p.push(".codex");
|
||||
p
|
||||
})
|
||||
}
|
||||
|
||||
fn slugify_prompt(prompt: &str, max_len: usize) -> String {
|
||||
let mut out = String::with_capacity(prompt.len());
|
||||
let mut prev_hyphen = false;
|
||||
for ch in prompt.chars() {
|
||||
let c = ch.to_ascii_lowercase();
|
||||
let keep = matches!(c, 'a'..='z' | '0'..='9');
|
||||
if keep {
|
||||
out.push(c);
|
||||
prev_hyphen = false;
|
||||
} else if c.is_ascii_whitespace() || matches!(c, '-' | '_' | '+') {
|
||||
if !prev_hyphen && !out.is_empty() {
|
||||
out.push('-');
|
||||
prev_hyphen = true;
|
||||
}
|
||||
} else {
|
||||
// skip other punctuation/symbols
|
||||
}
|
||||
if out.len() >= max_len {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Trim trailing hyphens
|
||||
while out.ends_with('-') {
|
||||
out.pop();
|
||||
}
|
||||
if out.is_empty() {
|
||||
"task".to_string()
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn git_output(repo_dir: &Path, args: &[&str]) -> anyhow::Result<String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(repo_dir)
|
||||
.output()
|
||||
.with_context(|| format!("running git {args:?}"))?;
|
||||
if !out.status.success() {
|
||||
anyhow::bail!(
|
||||
"git {:?} failed with status {}: {}",
|
||||
args,
|
||||
out.status,
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
);
|
||||
}
|
||||
Ok(String::from_utf8_lossy(&out.stdout).trim().to_string())
|
||||
}
|
||||
|
||||
fn git_capture_stdout(repo_dir: &Path, args: &[&str]) -> anyhow::Result<Vec<u8>> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(repo_dir)
|
||||
.output()
|
||||
.with_context(|| format!("running git {args:?}"))?;
|
||||
if !out.status.success() {
|
||||
anyhow::bail!(
|
||||
"git {:?} failed with status {}: {}",
|
||||
args,
|
||||
out.status,
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
);
|
||||
}
|
||||
Ok(out.stdout)
|
||||
}
|
||||
|
||||
fn count_files_in_patch(diff: &[u8]) -> usize {
|
||||
// Count occurrences of lines starting with "diff --git ", which mark file boundaries.
|
||||
// This works for text and binary patches produced by `git diff --binary`.
|
||||
let mut count = 0usize;
|
||||
for line in diff.split(|&b| b == b'\n') {
|
||||
if line.starts_with(b"diff --git ") {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
pub async fn run_concurrent_flow(
|
||||
prompt: String,
|
||||
cli_config_overrides: CliConfigOverrides,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
automerge: bool,
|
||||
quiet: bool,
|
||||
) -> anyhow::Result<ConcurrentRunResult> {
|
||||
let cwd = std::env::current_dir()?;
|
||||
|
||||
// Ensure we are in a git repo and find repo root.
|
||||
let repo_root_str = git_output(&cwd, &["rev-parse", "--show-toplevel"]);
|
||||
let repo_root = match repo_root_str {
|
||||
Ok(p) => PathBuf::from(p),
|
||||
Err(err) => {
|
||||
eprintln!("Not inside a Git repo: {err}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
// Determine current branch and original head commit.
|
||||
let current_branch = git_output(&repo_root, &["rev-parse", "--abbrev-ref", "HEAD"])
|
||||
.unwrap_or_else(|_| "HEAD".to_string());
|
||||
let original_head =
|
||||
git_output(&repo_root, &["rev-parse", "HEAD"]).context("finding original HEAD commit")?;
|
||||
|
||||
// Build worktree target path under $CODEX_HOME/worktrees/<repo>/<branch>
|
||||
let mut codex_home = compute_codex_home();
|
||||
codex_home.push("worktrees");
|
||||
// repo name = last component of repo_root
|
||||
let repo_name = repo_root
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "repo".to_string());
|
||||
codex_home.push(repo_name.clone());
|
||||
|
||||
// Prepare branch name: codex/<slug>, retrying with a numeric suffix to avoid races.
|
||||
let slug = slugify_prompt(&prompt, 64);
|
||||
let mut branch: String;
|
||||
let worktree_dir: PathBuf;
|
||||
let mut attempt: u32 = 1;
|
||||
loop {
|
||||
branch = if attempt == 1 {
|
||||
format!("codex/{slug}")
|
||||
} else {
|
||||
format!("codex/{slug}-{attempt}")
|
||||
};
|
||||
|
||||
let mut candidate_dir = codex_home.clone();
|
||||
candidate_dir.push(&branch);
|
||||
|
||||
// Create parent directories for candidate path.
|
||||
if let Some(parent) = candidate_dir.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
if !quiet {
|
||||
println!(
|
||||
"Creating worktree at {} with branch {}",
|
||||
candidate_dir.display(),
|
||||
branch
|
||||
);
|
||||
}
|
||||
|
||||
// Try to add worktree with new branch from current HEAD
|
||||
let worktree_path_str = candidate_dir.to_string_lossy().to_string();
|
||||
let add_status = Command::new("git")
|
||||
.arg("worktree")
|
||||
.arg("add")
|
||||
.arg("-b")
|
||||
.arg(&branch)
|
||||
.arg(&worktree_path_str)
|
||||
.current_dir(&repo_root)
|
||||
.status()?;
|
||||
if add_status.success() {
|
||||
worktree_dir = candidate_dir;
|
||||
break;
|
||||
}
|
||||
|
||||
attempt += 1;
|
||||
if attempt > 50 {
|
||||
anyhow::bail!("Failed to create git worktree after multiple attempts");
|
||||
}
|
||||
// Retry with a new branch name.
|
||||
}
|
||||
|
||||
// Either run codex exec inline (verbose) or as a subprocess with logs redirected.
|
||||
let mut log_file: Option<PathBuf> = None;
|
||||
let mut exec_exit_code: Option<i32> = None;
|
||||
if quiet {
|
||||
let exe = std::env::current_exe()
|
||||
.map_err(|e| anyhow::anyhow!("failed to locate current executable: {e}"))?;
|
||||
|
||||
// Prepare logs directory: $CODEX_HOME/logs/<repo_name>
|
||||
let mut logs_dir = compute_codex_home();
|
||||
logs_dir.push("logs");
|
||||
logs_dir.push(&repo_name);
|
||||
std::fs::create_dir_all(&logs_dir)?;
|
||||
|
||||
let sanitized_branch = branch.replace('/', "_");
|
||||
let log_path = logs_dir.join(format!("{sanitized_branch}.log"));
|
||||
let log_f = File::create(&log_path)?;
|
||||
log_file = Some(log_path.clone());
|
||||
|
||||
let mut cmd = Command::new(exe);
|
||||
cmd.arg("exec")
|
||||
.arg("--full-auto")
|
||||
.arg("--cd")
|
||||
.arg(worktree_dir.as_os_str())
|
||||
.stdout(Stdio::from(log_f.try_clone()?))
|
||||
.stderr(Stdio::from(log_f));
|
||||
|
||||
// Forward any root-level config overrides.
|
||||
for ov in cli_config_overrides.raw_overrides.iter() {
|
||||
cmd.arg("-c").arg(ov);
|
||||
}
|
||||
|
||||
// Append the prompt last (positional argument).
|
||||
cmd.arg(&prompt);
|
||||
|
||||
let status = cmd.status()?;
|
||||
exec_exit_code = status.code();
|
||||
if !status.success() && !quiet {
|
||||
eprintln!("codex exec failed with exit code {exec_exit_code:?}");
|
||||
}
|
||||
} else {
|
||||
// Build an ExecCli to run in full-auto mode at the worktree directory.
|
||||
let mut exec_cli = ExecCli {
|
||||
images: vec![],
|
||||
model: None,
|
||||
sandbox_mode: None,
|
||||
config_profile: None,
|
||||
full_auto: true,
|
||||
dangerously_bypass_approvals_and_sandbox: false,
|
||||
cwd: Some(worktree_dir.clone()),
|
||||
skip_git_repo_check: false,
|
||||
config_overrides: CliConfigOverrides::default(),
|
||||
color: Default::default(),
|
||||
json: false,
|
||||
last_message_file: None,
|
||||
prompt: Some(prompt.clone()),
|
||||
};
|
||||
|
||||
// Prepend any root-level config overrides.
|
||||
super::prepend_config_flags(&mut exec_cli.config_overrides, cli_config_overrides);
|
||||
|
||||
// Run codex exec
|
||||
if let Err(e) = codex_exec::run_main(exec_cli, codex_linux_sandbox_exe).await {
|
||||
eprintln!("codex exec failed: {e}");
|
||||
// Do not attempt to bring changes on failure; leave worktree for inspection.
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-commit changes in the worktree if any
|
||||
let status_out = Command::new("git")
|
||||
.args(["status", "--porcelain"])
|
||||
.current_dir(&worktree_dir)
|
||||
.output()?;
|
||||
let status_text = String::from_utf8_lossy(&status_out.stdout);
|
||||
let had_changes = !status_text.trim().is_empty();
|
||||
if had_changes {
|
||||
// Stage and commit
|
||||
if !Command::new("git")
|
||||
.args(["add", "-A"])
|
||||
.current_dir(&worktree_dir)
|
||||
.status()?
|
||||
.success()
|
||||
{
|
||||
anyhow::bail!("git add failed in worktree");
|
||||
}
|
||||
let commit_message = format!("Codex concurrent: {prompt}");
|
||||
if !Command::new("git")
|
||||
.args(["commit", "-m", &commit_message])
|
||||
.current_dir(&worktree_dir)
|
||||
.status()?
|
||||
.success()
|
||||
{
|
||||
if !quiet {
|
||||
eprintln!("No commit created (maybe no changes)");
|
||||
}
|
||||
} else if !quiet {
|
||||
println!("Committed changes in worktree branch {branch}");
|
||||
}
|
||||
} else if !quiet {
|
||||
println!("No changes detected in worktree; skipping commit.");
|
||||
}
|
||||
|
||||
if !automerge {
|
||||
if !quiet {
|
||||
println!(
|
||||
"Auto-merge disabled; leaving changes in worktree {} on branch {}.",
|
||||
worktree_dir.display(),
|
||||
branch
|
||||
);
|
||||
println!(
|
||||
"You can review and manually merge from that branch into {current_branch} when ready."
|
||||
);
|
||||
println!("Summary: Auto-merge disabled.");
|
||||
}
|
||||
return Ok(ConcurrentRunResult {
|
||||
branch,
|
||||
worktree_dir,
|
||||
log_file,
|
||||
exec_exit_code,
|
||||
_had_changes: had_changes,
|
||||
_applied_changes: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Bring the changes into the main working tree as UNSTAGED modifications.
|
||||
// We generate a patch from the original HEAD to the worktree branch tip, then apply with 3-way merge.
|
||||
if !quiet {
|
||||
println!("Applying changes from {branch} onto {current_branch} as unstaged modifications");
|
||||
}
|
||||
let range = format!("{original_head}..{branch}");
|
||||
let mut diff_bytes =
|
||||
git_capture_stdout(&repo_root, &["diff", "--binary", "--full-index", &range])?;
|
||||
|
||||
// Fallback: if there is nothing in the commit range (e.g., commit didn't happen),
|
||||
// try to capture uncommitted changes from the worktree working tree.
|
||||
if diff_bytes.is_empty() && had_changes {
|
||||
// If we saw changes earlier but no commit diff was produced, fall back to working tree diff.
|
||||
// This captures unstaged changes relative to HEAD in the worktree.
|
||||
diff_bytes =
|
||||
git_capture_stdout(&worktree_dir, &["diff", "--binary", "--full-index", "HEAD"])?;
|
||||
}
|
||||
|
||||
if diff_bytes.is_empty() {
|
||||
if !quiet {
|
||||
println!("Summary: 0 changes detected.");
|
||||
}
|
||||
return Ok(ConcurrentRunResult {
|
||||
branch,
|
||||
worktree_dir,
|
||||
log_file,
|
||||
exec_exit_code,
|
||||
_had_changes: had_changes,
|
||||
_applied_changes: Some(0),
|
||||
});
|
||||
}
|
||||
|
||||
let changed_files = count_files_in_patch(&diff_bytes);
|
||||
|
||||
let mut child = Command::new("git")
|
||||
.arg("apply")
|
||||
.arg("-3")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::inherit())
|
||||
.stderr(Stdio::inherit())
|
||||
.current_dir(&repo_root)
|
||||
.spawn()
|
||||
.context("spawning git apply")?;
|
||||
if let Some(stdin) = child.stdin.as_mut() {
|
||||
stdin
|
||||
.write_all(&diff_bytes)
|
||||
.context("writing patch to git apply stdin")?;
|
||||
}
|
||||
let status = child.wait().context("waiting for git apply")?;
|
||||
if !status.success() {
|
||||
if !quiet {
|
||||
eprintln!(
|
||||
"Applying changes failed. You can manually inspect {} and apply diffs.",
|
||||
worktree_dir.display()
|
||||
);
|
||||
println!("Summary: Apply failed.");
|
||||
}
|
||||
} else {
|
||||
if !quiet {
|
||||
println!("Changes applied to working tree (unstaged).");
|
||||
println!("Summary: Applied {changed_files} files changed.");
|
||||
}
|
||||
|
||||
// Cleanup: remove the worktree and delete the temporary branch.
|
||||
if !quiet {
|
||||
println!(
|
||||
"Cleaning up worktree {} and branch {}",
|
||||
worktree_dir.display(),
|
||||
branch
|
||||
);
|
||||
}
|
||||
let worktree_path_str = worktree_dir.to_string_lossy().to_string();
|
||||
let remove_status = Command::new("git")
|
||||
.args(["worktree", "remove", &worktree_path_str])
|
||||
.current_dir(&repo_root)
|
||||
.status();
|
||||
match remove_status {
|
||||
Ok(s) if s.success() => {
|
||||
// removed
|
||||
}
|
||||
_ => {
|
||||
if !quiet {
|
||||
eprintln!("git worktree remove failed; retrying with --force");
|
||||
}
|
||||
let _ = Command::new("git")
|
||||
.args(["worktree", "remove", "--force", &worktree_path_str])
|
||||
.current_dir(&repo_root)
|
||||
.status();
|
||||
}
|
||||
}
|
||||
|
||||
let del_status = Command::new("git")
|
||||
.args(["branch", "-D", &branch])
|
||||
.current_dir(&repo_root)
|
||||
.status();
|
||||
if let Ok(s) = del_status {
|
||||
if !s.success() && !quiet {
|
||||
eprintln!("Failed to delete branch {branch}");
|
||||
}
|
||||
} else if !quiet {
|
||||
eprintln!("Error running git branch -D {branch}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ConcurrentRunResult {
|
||||
branch,
|
||||
worktree_dir,
|
||||
log_file,
|
||||
exec_exit_code,
|
||||
_had_changes: had_changes,
|
||||
_applied_changes: Some(changed_files),
|
||||
})
|
||||
}
|
||||
|
||||
/// A Send-friendly variant used for best-of-n: run quietly (logs redirected) and do not auto-merge.
|
||||
/// This intentionally avoids referencing non-Send types from codex-exec.
|
||||
pub async fn run_concurrent_flow_quiet_no_automerge(
|
||||
prompt: String,
|
||||
cli_config_overrides: CliConfigOverrides,
|
||||
_codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
) -> anyhow::Result<ConcurrentRunResult> {
|
||||
let cwd = std::env::current_dir()?;
|
||||
|
||||
let repo_root_str = git_output(&cwd, &["rev-parse", "--show-toplevel"]);
|
||||
let repo_root = match repo_root_str {
|
||||
Ok(p) => PathBuf::from(p),
|
||||
Err(err) => {
|
||||
eprintln!("Not inside a Git repo: {err}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
// Capture basic repo info (not used further in quiet/no-automerge flow).
|
||||
|
||||
let mut codex_home = compute_codex_home();
|
||||
codex_home.push("worktrees");
|
||||
let repo_name = repo_root
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "repo".to_string());
|
||||
codex_home.push(repo_name.clone());
|
||||
|
||||
let slug = slugify_prompt(&prompt, 64);
|
||||
let mut branch: String;
|
||||
let worktree_dir: PathBuf;
|
||||
// Serialize worktree creation to avoid git repo lock contention across tasks.
|
||||
{
|
||||
let semaphore = GIT_WORKTREE_ADD_SEMAPHORE.get_or_init(|| Semaphore::new(1));
|
||||
let _permit = semaphore.acquire().await.expect("semaphore closed");
|
||||
|
||||
let mut attempt: u32 = 1;
|
||||
loop {
|
||||
branch = if attempt == 1 {
|
||||
format!("codex/{slug}")
|
||||
} else {
|
||||
format!("codex/{slug}-{attempt}")
|
||||
};
|
||||
|
||||
let mut candidate_dir = codex_home.clone();
|
||||
candidate_dir.push(&branch);
|
||||
|
||||
if let Some(parent) = candidate_dir.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let worktree_path_str = candidate_dir.to_string_lossy().to_string();
|
||||
let add_status = TokioCommand::new("git")
|
||||
.arg("worktree")
|
||||
.arg("add")
|
||||
.arg("-b")
|
||||
.arg(&branch)
|
||||
.arg(&worktree_path_str)
|
||||
.current_dir(&repo_root)
|
||||
.status()
|
||||
.await?;
|
||||
if add_status.success() {
|
||||
worktree_dir = candidate_dir;
|
||||
break;
|
||||
}
|
||||
attempt += 1;
|
||||
if attempt > 50 {
|
||||
anyhow::bail!("Failed to create git worktree after multiple attempts");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run the CLI in quiet mode (logs redirected).
|
||||
let exe = std::env::current_exe()
|
||||
.map_err(|e| anyhow::anyhow!("failed to locate current executable: {e}"))?;
|
||||
|
||||
let mut logs_dir = compute_codex_home();
|
||||
logs_dir.push("logs");
|
||||
logs_dir.push(&repo_name);
|
||||
std::fs::create_dir_all(&logs_dir)?;
|
||||
|
||||
let sanitized_branch = branch.replace('/', "_");
|
||||
let log_path = logs_dir.join(format!("{sanitized_branch}.log"));
|
||||
let log_f = File::create(&log_path)?;
|
||||
let log_file = Some(log_path.clone());
|
||||
|
||||
let mut cmd = TokioCommand::new(exe);
|
||||
cmd.arg("exec")
|
||||
.arg("--full-auto")
|
||||
.arg("--cd")
|
||||
.arg(worktree_dir.as_os_str())
|
||||
.stdout(Stdio::from(log_f.try_clone()?))
|
||||
.stderr(Stdio::from(log_f));
|
||||
for ov in cli_config_overrides.raw_overrides.iter() {
|
||||
cmd.arg("-c").arg(ov);
|
||||
}
|
||||
cmd.arg(&prompt);
|
||||
|
||||
let status = cmd.status().await?;
|
||||
let exec_exit_code = status.code();
|
||||
|
||||
// Auto-commit changes in the worktree if any
|
||||
let status_out = TokioCommand::new("git")
|
||||
.args(["status", "--porcelain"])
|
||||
.current_dir(&worktree_dir)
|
||||
.output()
|
||||
.await?;
|
||||
let status_text = String::from_utf8_lossy(&status_out.stdout);
|
||||
let had_changes = !status_text.trim().is_empty();
|
||||
if had_changes {
|
||||
if !TokioCommand::new("git")
|
||||
.args(["add", "-A"])
|
||||
.current_dir(&worktree_dir)
|
||||
.status()
|
||||
.await?
|
||||
.success()
|
||||
{
|
||||
anyhow::bail!("git add failed in worktree");
|
||||
}
|
||||
let commit_message = format!("Codex concurrent: {prompt}");
|
||||
let _ = TokioCommand::new("git")
|
||||
.args(["commit", "-m", &commit_message])
|
||||
.current_dir(&worktree_dir)
|
||||
.status()
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(ConcurrentRunResult {
|
||||
branch,
|
||||
worktree_dir,
|
||||
log_file,
|
||||
exec_exit_code,
|
||||
_had_changes: had_changes,
|
||||
_applied_changes: None,
|
||||
})
|
||||
}
|
||||
@@ -17,6 +17,7 @@ use codex_tui::Cli as TuiCli;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::proto::ProtoCli;
|
||||
mod concurrent;
|
||||
|
||||
/// Codex CLI
|
||||
///
|
||||
@@ -32,6 +33,22 @@ struct MultitoolCli {
|
||||
#[clap(flatten)]
|
||||
pub config_overrides: CliConfigOverrides,
|
||||
|
||||
/// Experimental:Launch a concurrent task in a separate Git worktree using the given prompt.
|
||||
/// Creates worktree under $CODEX_HOME/worktrees/<repo>/codex/<slug> and runs `codex exec` in full-auto mode.
|
||||
#[arg(long = "concurrent", value_name = "PROMPT")]
|
||||
pub concurrent: Option<String>,
|
||||
|
||||
/// When using --concurrent, also attempt to auto-merge the resulting changes
|
||||
/// back into the current working tree as unstaged modifications via
|
||||
/// a 3-way git apply. Disable with --automerge=false.
|
||||
#[arg(long = "automerge", default_value_t = true, action = clap::ArgAction::Set)]
|
||||
pub automerge: bool,
|
||||
|
||||
/// Run the same --concurrent prompt N times in separate worktrees and keep them all.
|
||||
/// Intended to generate multiple candidate solutions without auto-merging.
|
||||
#[arg(long = "best-of-n", value_name = "N", default_value_t = 1)]
|
||||
pub best_of_n: usize,
|
||||
|
||||
#[clap(flatten)]
|
||||
interactive: TuiCli,
|
||||
|
||||
@@ -116,6 +133,87 @@ fn main() -> anyhow::Result<()> {
|
||||
async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let cli = MultitoolCli::parse();
|
||||
|
||||
// Handle --concurrent at the root level.
|
||||
if let Some(prompt) = cli.concurrent.clone() {
|
||||
if cli.subcommand.is_some() {
|
||||
eprintln!("--concurrent cannot be used together with a subcommand");
|
||||
std::process::exit(2);
|
||||
}
|
||||
let runs = if cli.best_of_n == 0 { 1 } else { cli.best_of_n };
|
||||
if runs > 1 {
|
||||
println!(
|
||||
"Running best-of-n with {runs} runs; auto-merge will be disabled and worktrees kept."
|
||||
);
|
||||
|
||||
// Launch all runs concurrently and collect results as they finish.
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
for _ in 0..runs {
|
||||
let prompt = prompt.clone();
|
||||
let overrides = cli.config_overrides.clone();
|
||||
let sandbox = codex_linux_sandbox_exe.clone();
|
||||
join_set.spawn(async move {
|
||||
concurrent::run_concurrent_flow_quiet_no_automerge(prompt, overrides, sandbox)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut results: Vec<concurrent::ConcurrentRunResult> = Vec::with_capacity(runs);
|
||||
while let Some(join_result) = join_set.join_next().await {
|
||||
match join_result {
|
||||
Ok(Ok(res)) => {
|
||||
println!(
|
||||
"task finished for branch: {}\n, directory: {}",
|
||||
res.branch,
|
||||
res.worktree_dir.display()
|
||||
);
|
||||
results.push(res);
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
eprintln!("concurrent task failed: {err}");
|
||||
}
|
||||
Err(join_err) => {
|
||||
eprintln!("failed to join concurrent task: {join_err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nBest-of-n summary:");
|
||||
for r in &results {
|
||||
let status = match r.exec_exit_code {
|
||||
Some(0) => "OK",
|
||||
Some(_code) => "FAIL",
|
||||
None => "OK",
|
||||
};
|
||||
let log = r
|
||||
.log_file
|
||||
.as_ref()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "<no log>".to_string());
|
||||
println!(
|
||||
"[{status}] branch={} worktree={} log={}",
|
||||
r.branch,
|
||||
r.worktree_dir.display(),
|
||||
log
|
||||
);
|
||||
}
|
||||
} else {
|
||||
concurrent::run_concurrent_flow(
|
||||
prompt,
|
||||
cli.config_overrides,
|
||||
codex_linux_sandbox_exe,
|
||||
cli.automerge,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if cli.best_of_n > 1 {
|
||||
eprintln!("--best-of-n requires --concurrent <PROMPT>");
|
||||
std::process::exit(2);
|
||||
}
|
||||
|
||||
match cli.subcommand {
|
||||
None => {
|
||||
let mut tui_cli = cli.interactive;
|
||||
|
||||
@@ -7,6 +7,7 @@ pub fn summarize_sandbox_policy(sandbox_policy: &SandboxPolicy) -> String {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots,
|
||||
network_access,
|
||||
include_default_writable_roots,
|
||||
} => {
|
||||
let mut summary = "workspace-write".to_string();
|
||||
if !writable_roots.is_empty() {
|
||||
@@ -19,6 +20,9 @@ pub fn summarize_sandbox_policy(sandbox_policy: &SandboxPolicy) -> String {
|
||||
.join(", ")
|
||||
));
|
||||
}
|
||||
if !*include_default_writable_roots {
|
||||
summary.push_str(" (exact writable roots)");
|
||||
}
|
||||
if *network_access {
|
||||
summary.push_str(" (network access enabled)");
|
||||
}
|
||||
|
||||
@@ -259,6 +259,8 @@ disk, but attempts to write a file or access the network will be blocked.
|
||||
|
||||
A more relaxed policy is `workspace-write`. When specified, the current working directory for the Codex task will be writable (as well as `$TMPDIR` on macOS). Note that the CLI defaults to using the directory where it was spawned as `cwd`, though this can be overridden using `--cwd/-C`.
|
||||
|
||||
On macOS (and soon Linux), all writable roots (including `cwd`) that contain a `.git/` folder _as an immediate child_ will configure the `.git/` folder to be read-only while the rest of the Git repository will be writable. This means that commands like `git commit` will fail, by default (as it entails writing to `.git/`), and will require Codex to ask for permission.
|
||||
|
||||
```toml
|
||||
# same as `--sandbox workspace-write`
|
||||
sandbox_mode = "workspace-write"
|
||||
|
||||
@@ -31,6 +31,7 @@ rand = "0.9"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_bytes = "0.11"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
strum_macros = "0.27.2"
|
||||
|
||||
@@ -48,6 +48,7 @@ use crate::error::SandboxErr;
|
||||
use crate::exec::ExecParams;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::StdoutStream;
|
||||
use crate::exec::process_exec_tool_call;
|
||||
use crate::exec_env::create_env;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
@@ -1372,7 +1373,7 @@ async fn run_compact_task(
|
||||
let mut retries = 0;
|
||||
|
||||
loop {
|
||||
let attempt_result = drain_to_completed(&sess, &prompt).await;
|
||||
let attempt_result = drain_to_completed(&sess, &sub_id, &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => break,
|
||||
@@ -1759,6 +1760,11 @@ async fn handle_container_exec_with_params(
|
||||
sess.ctrl_c.clone(),
|
||||
&sess.sandbox_policy,
|
||||
&sess.codex_linux_sandbox_exe,
|
||||
Some(StdoutStream {
|
||||
sub_id: sub_id.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tx_event: sess.tx_event.clone(),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1879,6 +1885,11 @@ async fn handle_sandbox_error(
|
||||
sess.ctrl_c.clone(),
|
||||
&sess.sandbox_policy,
|
||||
&sess.codex_linux_sandbox_exe,
|
||||
Some(StdoutStream {
|
||||
sub_id: sub_id.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tx_event: sess.tx_event.clone(),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1990,7 +2001,7 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
})
|
||||
}
|
||||
|
||||
async fn drain_to_completed(sess: &Session, prompt: &Prompt) -> CodexResult<()> {
|
||||
async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> CodexResult<()> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
@@ -2000,7 +2011,32 @@ async fn drain_to_completed(sess: &Session, prompt: &Prompt) -> CodexResult<()>
|
||||
));
|
||||
};
|
||||
match event {
|
||||
Ok(ResponseEvent::Completed { .. }) => return Ok(()),
|
||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||
// Record only to in-memory conversation history; avoid state snapshot.
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.history.record_items(std::slice::from_ref(&item));
|
||||
}
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
token_usage,
|
||||
}) => {
|
||||
let token_usage = match token_usage {
|
||||
Some(usage) => usage,
|
||||
None => {
|
||||
return Err(CodexErr::Stream(
|
||||
"token_usage was None in ResponseEvent::Completed".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
sess.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(token_usage),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return Ok(());
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
@@ -356,6 +356,7 @@ impl ConfigToml {
|
||||
Some(s) => SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: s.writable_roots.clone(),
|
||||
network_access: s.network_access,
|
||||
include_default_writable_roots: true,
|
||||
},
|
||||
None => SandboxPolicy::new_workspace_write_policy(),
|
||||
},
|
||||
@@ -727,6 +728,7 @@ writable_roots = [
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![PathBuf::from("/tmp")],
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
},
|
||||
sandbox_workspace_write_cfg.derive_sandbox_policy(sandbox_mode_override)
|
||||
);
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_channel::Sender;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::BufReader;
|
||||
@@ -19,10 +20,15 @@ use tokio::sync::Notify;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::error::SandboxErr;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::ExecCommandOutputDeltaEvent;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::seatbelt::spawn_command_under_seatbelt;
|
||||
use crate::spawn::StdioPolicy;
|
||||
use crate::spawn::spawn_child_async;
|
||||
use serde_bytes::ByteBuf;
|
||||
|
||||
// Maximum we send for each stream, which is either:
|
||||
// - 10KiB OR
|
||||
@@ -56,18 +62,26 @@ pub enum SandboxType {
|
||||
LinuxSeccomp,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StdoutStream {
|
||||
pub sub_id: String,
|
||||
pub call_id: String,
|
||||
pub tx_event: Sender<Event>,
|
||||
}
|
||||
|
||||
pub async fn process_exec_tool_call(
|
||||
params: ExecParams,
|
||||
sandbox_type: SandboxType,
|
||||
ctrl_c: Arc<Notify>,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
codex_linux_sandbox_exe: &Option<PathBuf>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<ExecToolCallOutput> {
|
||||
let start = Instant::now();
|
||||
|
||||
let raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr> = match sandbox_type
|
||||
{
|
||||
SandboxType::None => exec(params, sandbox_policy, ctrl_c).await,
|
||||
SandboxType::None => exec(params, sandbox_policy, ctrl_c, stdout_stream.clone()).await,
|
||||
SandboxType::MacosSeatbelt => {
|
||||
let ExecParams {
|
||||
command,
|
||||
@@ -83,7 +97,7 @@ pub async fn process_exec_tool_call(
|
||||
env,
|
||||
)
|
||||
.await?;
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms).await
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms, stdout_stream.clone()).await
|
||||
}
|
||||
SandboxType::LinuxSeccomp => {
|
||||
let ExecParams {
|
||||
@@ -106,7 +120,7 @@ pub async fn process_exec_tool_call(
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms).await
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms, stdout_stream).await
|
||||
}
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
@@ -233,6 +247,7 @@ async fn exec(
|
||||
}: ExecParams,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
ctrl_c: Arc<Notify>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<RawExecToolCallOutput> {
|
||||
let (program, args) = command.split_first().ok_or_else(|| {
|
||||
CodexErr::Io(io::Error::new(
|
||||
@@ -251,7 +266,7 @@ async fn exec(
|
||||
env,
|
||||
)
|
||||
.await?;
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms).await
|
||||
consume_truncated_output(child, ctrl_c, timeout_ms, stdout_stream).await
|
||||
}
|
||||
|
||||
/// Consumes the output of a child process, truncating it so it is suitable for
|
||||
@@ -260,6 +275,7 @@ pub(crate) async fn consume_truncated_output(
|
||||
mut child: Child,
|
||||
ctrl_c: Arc<Notify>,
|
||||
timeout_ms: Option<u64>,
|
||||
stdout_stream: Option<StdoutStream>,
|
||||
) -> Result<RawExecToolCallOutput> {
|
||||
// Both stdout and stderr were configured with `Stdio::piped()`
|
||||
// above, therefore `take()` should normally return `Some`. If it doesn't
|
||||
@@ -280,11 +296,15 @@ pub(crate) async fn consume_truncated_output(
|
||||
BufReader::new(stdout_reader),
|
||||
MAX_STREAM_OUTPUT,
|
||||
MAX_STREAM_OUTPUT_LINES,
|
||||
stdout_stream.clone(),
|
||||
false,
|
||||
));
|
||||
let stderr_handle = tokio::spawn(read_capped(
|
||||
BufReader::new(stderr_reader),
|
||||
MAX_STREAM_OUTPUT,
|
||||
MAX_STREAM_OUTPUT_LINES,
|
||||
stdout_stream.clone(),
|
||||
true,
|
||||
));
|
||||
|
||||
let interrupted = ctrl_c.notified();
|
||||
@@ -318,10 +338,12 @@ pub(crate) async fn consume_truncated_output(
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_capped<R: AsyncRead + Unpin>(
|
||||
async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
|
||||
mut reader: R,
|
||||
max_output: usize,
|
||||
max_lines: usize,
|
||||
stream: Option<StdoutStream>,
|
||||
is_stderr: bool,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut buf = Vec::with_capacity(max_output.min(8 * 1024));
|
||||
let mut tmp = [0u8; 8192];
|
||||
@@ -335,6 +357,25 @@ async fn read_capped<R: AsyncRead + Unpin>(
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(stream) = &stream {
|
||||
let chunk = tmp[..n].to_vec();
|
||||
let msg = EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
call_id: stream.call_id.clone(),
|
||||
stream: if is_stderr {
|
||||
ExecOutputStream::Stderr
|
||||
} else {
|
||||
ExecOutputStream::Stdout
|
||||
},
|
||||
chunk: ByteBuf::from(chunk),
|
||||
});
|
||||
let event = Event {
|
||||
id: stream.sub_id.clone(),
|
||||
msg,
|
||||
};
|
||||
#[allow(clippy::let_unit_value)]
|
||||
let _ = stream.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
// Copy into the buffer only while we still have byte and line budget.
|
||||
if remaining_bytes > 0 && remaining_lines > 0 {
|
||||
let mut copy_len = 0;
|
||||
|
||||
@@ -13,6 +13,7 @@ use std::time::Duration;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_bytes::ByteBuf;
|
||||
use strum_macros::Display;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -179,9 +180,29 @@ pub enum SandboxPolicy {
|
||||
/// default.
|
||||
#[serde(default)]
|
||||
network_access: bool,
|
||||
|
||||
/// When set to `true`, will include defaults like the current working
|
||||
/// directory and TMPDIR (on macOS). When `false`, only `writable_roots`
|
||||
/// are used. (Mainly used for testing.)
|
||||
#[serde(default = "default_true")]
|
||||
include_default_writable_roots: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// A writable root path accompanied by a list of subpaths that should remain
|
||||
/// read‑only even when the root is writable. This is primarily used to ensure
|
||||
/// top‑level VCS metadata directories (e.g. `.git`) under a writable root are
|
||||
/// not modified by the agent.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct WritableRoot {
|
||||
pub root: PathBuf,
|
||||
pub read_only_subpaths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl FromStr for SandboxPolicy {
|
||||
type Err = serde_json::Error;
|
||||
|
||||
@@ -203,6 +224,7 @@ impl SandboxPolicy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,27 +250,51 @@ impl SandboxPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the list of writable roots that should be passed down to the
|
||||
/// Landlock rules installer, tailored to the current working directory.
|
||||
pub fn get_writable_roots_with_cwd(&self, cwd: &Path) -> Vec<PathBuf> {
|
||||
/// Returns the list of writable roots (tailored to the current working
|
||||
/// directory) together with subpaths that should remain read‑only under
|
||||
/// each writable root.
|
||||
pub fn get_writable_roots_with_cwd(&self, cwd: &Path) -> Vec<WritableRoot> {
|
||||
match self {
|
||||
SandboxPolicy::DangerFullAccess => Vec::new(),
|
||||
SandboxPolicy::ReadOnly => Vec::new(),
|
||||
SandboxPolicy::WorkspaceWrite { writable_roots, .. } => {
|
||||
let mut roots = writable_roots.clone();
|
||||
roots.push(cwd.to_path_buf());
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots,
|
||||
include_default_writable_roots,
|
||||
..
|
||||
} => {
|
||||
// Start from explicitly configured writable roots.
|
||||
let mut roots: Vec<PathBuf> = writable_roots.clone();
|
||||
|
||||
// Also include the per-user tmp dir on macOS.
|
||||
// Note this is added dynamically rather than storing it in
|
||||
// writable_roots because writable_roots contains only static
|
||||
// values deserialized from the config file.
|
||||
if cfg!(target_os = "macos") {
|
||||
if let Some(tmpdir) = std::env::var_os("TMPDIR") {
|
||||
roots.push(PathBuf::from(tmpdir));
|
||||
// Optionally include defaults (cwd and TMPDIR on macOS).
|
||||
if *include_default_writable_roots {
|
||||
roots.push(cwd.to_path_buf());
|
||||
|
||||
// Also include the per-user tmp dir on macOS.
|
||||
// Note this is added dynamically rather than storing it in
|
||||
// `writable_roots` because `writable_roots` contains only static
|
||||
// values deserialized from the config file.
|
||||
if cfg!(target_os = "macos") {
|
||||
if let Some(tmpdir) = std::env::var_os("TMPDIR") {
|
||||
roots.push(PathBuf::from(tmpdir));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each root, compute subpaths that should remain read-only.
|
||||
roots
|
||||
.into_iter()
|
||||
.map(|writable_root| {
|
||||
let mut subpaths = Vec::new();
|
||||
let top_level_git = writable_root.join(".git");
|
||||
if top_level_git.is_dir() {
|
||||
subpaths.push(top_level_git);
|
||||
}
|
||||
WritableRoot {
|
||||
root: writable_root,
|
||||
read_only_subpaths: subpaths,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,6 +369,9 @@ pub enum EventMsg {
|
||||
/// Notification that the server is about to execute a command.
|
||||
ExecCommandBegin(ExecCommandBeginEvent),
|
||||
|
||||
/// Incremental chunk of output from a running command.
|
||||
ExecCommandOutputDelta(ExecCommandOutputDeltaEvent),
|
||||
|
||||
ExecCommandEnd(ExecCommandEndEvent),
|
||||
|
||||
ExecApprovalRequest(ExecApprovalRequestEvent),
|
||||
@@ -476,6 +525,24 @@ pub struct ExecCommandEndEvent {
|
||||
pub exit_code: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ExecOutputStream {
|
||||
Stdout,
|
||||
Stderr,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ExecCommandOutputDeltaEvent {
|
||||
/// Identifier for the ExecCommandBegin that produced this chunk.
|
||||
pub call_id: String,
|
||||
/// Which stream produced this chunk.
|
||||
pub stream: ExecOutputStream,
|
||||
/// Raw bytes from the stream (may not be valid UTF-8).
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub chunk: ByteBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ExecApprovalRequestEvent {
|
||||
/// Identifier for the associated exec call, if available.
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::path::PathBuf;
|
||||
use tokio::process::Child;
|
||||
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use crate::spawn::StdioPolicy;
|
||||
use crate::spawn::spawn_child_async;
|
||||
|
||||
@@ -20,10 +21,11 @@ pub async fn spawn_command_under_seatbelt(
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
cwd: PathBuf,
|
||||
stdio_policy: StdioPolicy,
|
||||
env: HashMap<String, String>,
|
||||
mut env: HashMap<String, String>,
|
||||
) -> std::io::Result<Child> {
|
||||
let args = create_seatbelt_command_args(command, sandbox_policy, &cwd);
|
||||
let arg0 = None;
|
||||
env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
|
||||
spawn_child_async(
|
||||
PathBuf::from(MACOS_PATH_TO_SEATBELT_EXECUTABLE),
|
||||
args,
|
||||
@@ -50,16 +52,38 @@ fn create_seatbelt_command_args(
|
||||
)
|
||||
} else {
|
||||
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(cwd);
|
||||
let (writable_folder_policies, cli_args): (Vec<String>, Vec<String>) = writable_roots
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, root)| {
|
||||
let param_name = format!("WRITABLE_ROOT_{index}");
|
||||
let policy: String = format!("(subpath (param \"{param_name}\"))");
|
||||
let cli_arg = format!("-D{param_name}={}", root.to_string_lossy());
|
||||
(policy, cli_arg)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
let mut writable_folder_policies: Vec<String> = Vec::new();
|
||||
let mut cli_args: Vec<String> = Vec::new();
|
||||
|
||||
for (index, wr) in writable_roots.iter().enumerate() {
|
||||
// Canonicalize to avoid mismatches like /var vs /private/var on macOS.
|
||||
let canonical_root = wr.root.canonicalize().unwrap_or_else(|_| wr.root.clone());
|
||||
let root_param = format!("WRITABLE_ROOT_{index}");
|
||||
cli_args.push(format!(
|
||||
"-D{root_param}={}",
|
||||
canonical_root.to_string_lossy()
|
||||
));
|
||||
|
||||
if wr.read_only_subpaths.is_empty() {
|
||||
writable_folder_policies.push(format!("(subpath (param \"{root_param}\"))"));
|
||||
} else {
|
||||
// Add parameters for each read-only subpath and generate
|
||||
// the `(require-not ...)` clauses.
|
||||
let mut require_parts: Vec<String> = Vec::new();
|
||||
require_parts.push(format!("(subpath (param \"{root_param}\"))"));
|
||||
for (subpath_index, ro) in wr.read_only_subpaths.iter().enumerate() {
|
||||
let canonical_ro = ro.canonicalize().unwrap_or_else(|_| ro.clone());
|
||||
let ro_param = format!("WRITABLE_ROOT_{index}_RO_{subpath_index}");
|
||||
cli_args.push(format!("-D{ro_param}={}", canonical_ro.to_string_lossy()));
|
||||
require_parts
|
||||
.push(format!("(require-not (subpath (param \"{ro_param}\")))"));
|
||||
}
|
||||
let policy_component = format!("(require-all {} )", require_parts.join(" "));
|
||||
writable_folder_policies.push(policy_component);
|
||||
}
|
||||
}
|
||||
|
||||
if writable_folder_policies.is_empty() {
|
||||
("".to_string(), Vec::<String>::new())
|
||||
} else {
|
||||
@@ -88,9 +112,201 @@ fn create_seatbelt_command_args(
|
||||
let full_policy = format!(
|
||||
"{MACOS_SEATBELT_BASE_POLICY}\n{file_read_policy}\n{file_write_policy}\n{network_policy}"
|
||||
);
|
||||
|
||||
let mut seatbelt_args: Vec<String> = vec!["-p".to_string(), full_policy];
|
||||
seatbelt_args.extend(extra_cli_args);
|
||||
seatbelt_args.push("--".to_string());
|
||||
seatbelt_args.extend(command);
|
||||
seatbelt_args
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![expect(clippy::expect_used)]
|
||||
use super::MACOS_SEATBELT_BASE_POLICY;
|
||||
use super::create_seatbelt_command_args;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn create_seatbelt_args_with_read_only_git_subpath() {
|
||||
// Create a temporary workspace with two writable roots: one containing
|
||||
// a top-level .git directory and one without it.
|
||||
let tmp = TempDir::new().expect("tempdir");
|
||||
let PopulatedTmp {
|
||||
root_with_git,
|
||||
root_without_git,
|
||||
root_with_git_canon,
|
||||
root_with_git_git_canon,
|
||||
root_without_git_canon,
|
||||
} = populate_tmpdir(tmp.path());
|
||||
|
||||
// Build a policy that only includes the two test roots as writable and
|
||||
// does not automatically include defaults like cwd or TMPDIR.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![root_with_git.clone(), root_without_git.clone()],
|
||||
network_access: false,
|
||||
include_default_writable_roots: false,
|
||||
};
|
||||
|
||||
let args = create_seatbelt_command_args(
|
||||
vec!["/bin/echo".to_string(), "hello".to_string()],
|
||||
&policy,
|
||||
tmp.path(),
|
||||
);
|
||||
|
||||
// Build the expected policy text using a raw string for readability.
|
||||
// Note that the policy includes:
|
||||
// - the base policy,
|
||||
// - read-only access to the filesystem,
|
||||
// - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
|
||||
let expected_policy = format!(
|
||||
r#"{MACOS_SEATBELT_BASE_POLICY}
|
||||
; allow read-only file operations
|
||||
(allow file-read*)
|
||||
(allow file-write*
|
||||
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1"))
|
||||
)
|
||||
"#,
|
||||
);
|
||||
|
||||
let expected_args = vec![
|
||||
"-p".to_string(),
|
||||
expected_policy,
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0={}",
|
||||
root_with_git_canon.to_string_lossy()
|
||||
),
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0_RO_0={}",
|
||||
root_with_git_git_canon.to_string_lossy()
|
||||
),
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_1={}",
|
||||
root_without_git_canon.to_string_lossy()
|
||||
),
|
||||
"--".to_string(),
|
||||
"/bin/echo".to_string(),
|
||||
"hello".to_string(),
|
||||
];
|
||||
|
||||
assert_eq!(args, expected_args);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_seatbelt_args_for_cwd_as_git_repo() {
|
||||
// Create a temporary workspace with two writable roots: one containing
|
||||
// a top-level .git directory and one without it.
|
||||
let tmp = TempDir::new().expect("tempdir");
|
||||
let PopulatedTmp {
|
||||
root_with_git,
|
||||
root_with_git_canon,
|
||||
root_with_git_git_canon,
|
||||
..
|
||||
} = populate_tmpdir(tmp.path());
|
||||
|
||||
// Build a policy that does not specify any writable_roots, but does
|
||||
// use the default ones (cwd and TMPDIR) and verifies the `.git` check
|
||||
// is done properly for cwd.
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
};
|
||||
|
||||
let args = create_seatbelt_command_args(
|
||||
vec!["/bin/echo".to_string(), "hello".to_string()],
|
||||
&policy,
|
||||
root_with_git.as_path(),
|
||||
);
|
||||
|
||||
let tmpdir_env_var = if cfg!(target_os = "macos") {
|
||||
std::env::var("TMPDIR")
|
||||
.ok()
|
||||
.map(PathBuf::from)
|
||||
.and_then(|p| p.canonicalize().ok())
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tempdir_policy_entry = if tmpdir_env_var.is_some() {
|
||||
" (subpath (param \"WRITABLE_ROOT_1\"))"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
// Build the expected policy text using a raw string for readability.
|
||||
// Note that the policy includes:
|
||||
// - the base policy,
|
||||
// - read-only access to the filesystem,
|
||||
// - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
|
||||
let expected_policy = format!(
|
||||
r#"{MACOS_SEATBELT_BASE_POLICY}
|
||||
; allow read-only file operations
|
||||
(allow file-read*)
|
||||
(allow file-write*
|
||||
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ){tempdir_policy_entry}
|
||||
)
|
||||
"#,
|
||||
);
|
||||
|
||||
let mut expected_args = vec![
|
||||
"-p".to_string(),
|
||||
expected_policy,
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0={}",
|
||||
root_with_git_canon.to_string_lossy()
|
||||
),
|
||||
format!(
|
||||
"-DWRITABLE_ROOT_0_RO_0={}",
|
||||
root_with_git_git_canon.to_string_lossy()
|
||||
),
|
||||
];
|
||||
|
||||
if let Some(p) = tmpdir_env_var {
|
||||
expected_args.push(format!("-DWRITABLE_ROOT_1={p}"));
|
||||
}
|
||||
|
||||
expected_args.extend(vec![
|
||||
"--".to_string(),
|
||||
"/bin/echo".to_string(),
|
||||
"hello".to_string(),
|
||||
]);
|
||||
|
||||
assert_eq!(args, expected_args);
|
||||
}
|
||||
|
||||
struct PopulatedTmp {
|
||||
root_with_git: PathBuf,
|
||||
root_without_git: PathBuf,
|
||||
root_with_git_canon: PathBuf,
|
||||
root_with_git_git_canon: PathBuf,
|
||||
root_without_git_canon: PathBuf,
|
||||
}
|
||||
|
||||
fn populate_tmpdir(tmp: &Path) -> PopulatedTmp {
|
||||
let root_with_git = tmp.join("with_git");
|
||||
let root_without_git = tmp.join("no_git");
|
||||
fs::create_dir_all(&root_with_git).expect("create with_git");
|
||||
fs::create_dir_all(&root_without_git).expect("create no_git");
|
||||
fs::create_dir_all(root_with_git.join(".git")).expect("create .git");
|
||||
|
||||
// Ensure we have canonical paths for -D parameter matching.
|
||||
let root_with_git_canon = root_with_git.canonicalize().expect("canonicalize with_git");
|
||||
let root_with_git_git_canon = root_with_git_canon.join(".git");
|
||||
let root_without_git_canon = root_without_git
|
||||
.canonicalize()
|
||||
.expect("canonicalize no_git");
|
||||
PopulatedTmp {
|
||||
root_with_git,
|
||||
root_without_git,
|
||||
root_with_git_canon,
|
||||
root_with_git_git_canon,
|
||||
root_without_git_canon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +220,7 @@ mod tests {
|
||||
Arc::new(Notify::new()),
|
||||
&SandboxPolicy::DangerFullAccess,
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -17,6 +17,11 @@ use crate::protocol::SandboxPolicy;
|
||||
/// attributes, so this may change in the future.
|
||||
pub const CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR: &str = "CODEX_SANDBOX_NETWORK_DISABLED";
|
||||
|
||||
/// Should be set when the process is spawned under a sandbox. Currently, the
|
||||
/// value is "seatbelt" for macOS, but it may change in the future to
|
||||
/// accommodate sandboxing configuration and other sandboxing mechanisms.
|
||||
pub const CODEX_SANDBOX_ENV_VAR: &str = "CODEX_SANDBOX";
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum StdioPolicy {
|
||||
RedirectForShellTool,
|
||||
|
||||
143
codex-rs/core/tests/exec_stream_events.rs
Normal file
143
codex-rs/core/tests/exec_stream_events.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
#![cfg(unix)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_channel::Receiver;
|
||||
use codex_core::exec::ExecParams;
|
||||
use codex_core::exec::SandboxType;
|
||||
use codex_core::exec::StdoutStream;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecCommandOutputDeltaEvent;
|
||||
use codex_core::protocol::ExecOutputStream;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
fn collect_stdout_events(rx: Receiver<Event>) -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
while let Ok(ev) = rx.try_recv() {
|
||||
if let EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
stream: ExecOutputStream::Stdout,
|
||||
chunk,
|
||||
..
|
||||
}) = ev.msg
|
||||
{
|
||||
out.extend_from_slice(&chunk);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exec_stdout_stream_events_echo() {
|
||||
let (tx, rx) = async_channel::unbounded::<Event>();
|
||||
|
||||
let stdout_stream = StdoutStream {
|
||||
sub_id: "test-sub".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
tx_event: tx,
|
||||
};
|
||||
|
||||
let cmd = vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
// Use printf for predictable behavior across shells
|
||||
"printf 'hello-world\n'".to_string(),
|
||||
];
|
||||
|
||||
let params = ExecParams {
|
||||
command: cmd,
|
||||
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
||||
timeout_ms: Some(5_000),
|
||||
env: HashMap::new(),
|
||||
};
|
||||
|
||||
let ctrl_c = Arc::new(Notify::new());
|
||||
let policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
let result = process_exec_tool_call(
|
||||
params,
|
||||
SandboxType::None,
|
||||
ctrl_c,
|
||||
&policy,
|
||||
&None,
|
||||
Some(stdout_stream),
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => panic!("process_exec_tool_call failed: {e}"),
|
||||
};
|
||||
|
||||
assert_eq!(result.exit_code, 0);
|
||||
assert_eq!(result.stdout, "hello-world\n");
|
||||
|
||||
let streamed = collect_stdout_events(rx);
|
||||
// We should have received at least the same contents (possibly in one chunk)
|
||||
assert_eq!(String::from_utf8_lossy(&streamed), "hello-world\n");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exec_stderr_stream_events_echo() {
|
||||
let (tx, rx) = async_channel::unbounded::<Event>();
|
||||
|
||||
let stdout_stream = StdoutStream {
|
||||
sub_id: "test-sub".to_string(),
|
||||
call_id: "call-2".to_string(),
|
||||
tx_event: tx,
|
||||
};
|
||||
|
||||
let cmd = vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
// Write to stderr explicitly
|
||||
"printf 'oops\n' 1>&2".to_string(),
|
||||
];
|
||||
|
||||
let params = ExecParams {
|
||||
command: cmd,
|
||||
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
||||
timeout_ms: Some(5_000),
|
||||
env: HashMap::new(),
|
||||
};
|
||||
|
||||
let ctrl_c = Arc::new(Notify::new());
|
||||
let policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
let result = process_exec_tool_call(
|
||||
params,
|
||||
SandboxType::None,
|
||||
ctrl_c,
|
||||
&policy,
|
||||
&None,
|
||||
Some(stdout_stream),
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => panic!("process_exec_tool_call failed: {e}"),
|
||||
};
|
||||
|
||||
assert_eq!(result.exit_code, 0);
|
||||
assert_eq!(result.stdout, "");
|
||||
assert_eq!(result.stderr, "oops\n");
|
||||
|
||||
// Collect only stderr delta events
|
||||
let mut err = Vec::new();
|
||||
while let Ok(ev) = rx.try_recv() {
|
||||
if let EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
stream: ExecOutputStream::Stderr,
|
||||
chunk,
|
||||
..
|
||||
}) = ev.msg
|
||||
{
|
||||
err.extend_from_slice(&chunk);
|
||||
}
|
||||
}
|
||||
assert_eq!(String::from_utf8_lossy(&err), "oops\n");
|
||||
}
|
||||
195
codex-rs/core/tests/sandbox.rs
Normal file
195
codex-rs/core/tests/sandbox.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
#![cfg(target_os = "macos")]
|
||||
#![expect(clippy::expect_used)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_core::seatbelt::spawn_command_under_seatbelt;
|
||||
use codex_core::spawn::CODEX_SANDBOX_ENV_VAR;
|
||||
use codex_core::spawn::StdioPolicy;
|
||||
use tempfile::TempDir;
|
||||
|
||||
struct TestScenario {
|
||||
repo_parent: PathBuf,
|
||||
file_outside_repo: PathBuf,
|
||||
repo_root: PathBuf,
|
||||
file_in_repo_root: PathBuf,
|
||||
file_in_dot_git_dir: PathBuf,
|
||||
}
|
||||
|
||||
struct TestExpectations {
|
||||
file_outside_repo_is_writable: bool,
|
||||
file_in_repo_root_is_writable: bool,
|
||||
file_in_dot_git_dir_is_writable: bool,
|
||||
}
|
||||
|
||||
impl TestScenario {
|
||||
async fn run_test(&self, policy: &SandboxPolicy, expectations: TestExpectations) {
|
||||
if std::env::var(CODEX_SANDBOX_ENV_VAR) == Ok("seatbelt".to_string()) {
|
||||
eprintln!("{CODEX_SANDBOX_ENV_VAR} is set to 'seatbelt', skipping test.");
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
touch(&self.file_outside_repo, policy).await,
|
||||
expectations.file_outside_repo_is_writable
|
||||
);
|
||||
assert_eq!(
|
||||
self.file_outside_repo.exists(),
|
||||
expectations.file_outside_repo_is_writable
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
touch(&self.file_in_repo_root, policy).await,
|
||||
expectations.file_in_repo_root_is_writable
|
||||
);
|
||||
assert_eq!(
|
||||
self.file_in_repo_root.exists(),
|
||||
expectations.file_in_repo_root_is_writable
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
touch(&self.file_in_dot_git_dir, policy).await,
|
||||
expectations.file_in_dot_git_dir_is_writable
|
||||
);
|
||||
assert_eq!(
|
||||
self.file_in_dot_git_dir.exists(),
|
||||
expectations.file_in_dot_git_dir_is_writable
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// If the user has added a workspace root that is not a Git repo root, then
|
||||
/// the user has to specify `--skip-git-repo-check` or go through some
|
||||
/// interstitial that indicates they are taking on some risk because Git
|
||||
/// cannot be used to backup their work before the agent begins.
|
||||
///
|
||||
/// Because the user has agreed to this risk, we do not try find all .git
|
||||
/// folders in the workspace and block them (though we could change our
|
||||
/// position on this in the future).
|
||||
#[tokio::test]
|
||||
async fn if_parent_of_repo_is_writable_then_dot_git_folder_is_writable() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![test_scenario.repo_parent.clone()],
|
||||
network_access: false,
|
||||
include_default_writable_roots: false,
|
||||
};
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: true,
|
||||
file_in_repo_root_is_writable: true,
|
||||
file_in_dot_git_dir_is_writable: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// When the writable root is the root of a Git repository (as evidenced by the
|
||||
/// presence of a .git folder), then the .git folder should be read-only if
|
||||
/// the policy is `WorkspaceWrite`.
|
||||
#[tokio::test]
|
||||
async fn if_git_repo_is_writable_root_then_dot_git_folder_is_read_only() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![test_scenario.repo_root.clone()],
|
||||
network_access: false,
|
||||
include_default_writable_roots: false,
|
||||
};
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: false,
|
||||
file_in_repo_root_is_writable: true,
|
||||
file_in_dot_git_dir_is_writable: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Under DangerFullAccess, all writes should be permitted anywhere on disk,
|
||||
/// including inside the .git folder.
|
||||
#[tokio::test]
|
||||
async fn danger_full_access_allows_all_writes() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::DangerFullAccess;
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: true,
|
||||
file_in_repo_root_is_writable: true,
|
||||
file_in_dot_git_dir_is_writable: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Under ReadOnly, writes should not be permitted anywhere on disk.
|
||||
#[tokio::test]
|
||||
async fn read_only_forbids_all_writes() {
|
||||
let tmp = TempDir::new().expect("should be able to create temp dir");
|
||||
let test_scenario = create_test_scenario(&tmp);
|
||||
let policy = SandboxPolicy::ReadOnly;
|
||||
|
||||
test_scenario
|
||||
.run_test(
|
||||
&policy,
|
||||
TestExpectations {
|
||||
file_outside_repo_is_writable: false,
|
||||
file_in_repo_root_is_writable: false,
|
||||
file_in_dot_git_dir_is_writable: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn create_test_scenario(tmp: &TempDir) -> TestScenario {
|
||||
let repo_parent = tmp.path().to_path_buf();
|
||||
let repo_root = repo_parent.join("repo");
|
||||
let dot_git_dir = repo_root.join(".git");
|
||||
|
||||
std::fs::create_dir(&repo_root).expect("should be able to create repo root");
|
||||
std::fs::create_dir(&dot_git_dir).expect("should be able to create .git dir");
|
||||
|
||||
TestScenario {
|
||||
file_outside_repo: repo_parent.join("outside.txt"),
|
||||
repo_parent,
|
||||
file_in_repo_root: repo_root.join("repo_file.txt"),
|
||||
repo_root,
|
||||
file_in_dot_git_dir: dot_git_dir.join("dot_git_file.txt"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Note that `path` must be absolute.
|
||||
async fn touch(path: &Path, policy: &SandboxPolicy) -> bool {
|
||||
assert!(path.is_absolute(), "Path must be absolute: {path:?}");
|
||||
let mut child = spawn_command_under_seatbelt(
|
||||
vec![
|
||||
"/usr/bin/touch".to_string(),
|
||||
path.to_string_lossy().to_string(),
|
||||
],
|
||||
policy,
|
||||
std::env::current_dir().expect("should be able to get current dir"),
|
||||
StdioPolicy::RedirectForShellTool,
|
||||
HashMap::new(),
|
||||
)
|
||||
.await
|
||||
.expect("should be able to spawn command under seatbelt");
|
||||
child
|
||||
.wait()
|
||||
.await
|
||||
.expect("should be able to wait for child process")
|
||||
.success()
|
||||
}
|
||||
@@ -239,6 +239,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
cwd.to_string_lossy(),
|
||||
);
|
||||
}
|
||||
EventMsg::ExecCommandOutputDelta(_) => {}
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id,
|
||||
stdout,
|
||||
|
||||
@@ -36,7 +36,11 @@ pub(crate) fn apply_sandbox_policy_to_current_thread(
|
||||
}
|
||||
|
||||
if !sandbox_policy.has_full_disk_write_access() {
|
||||
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(cwd);
|
||||
let writable_roots = sandbox_policy
|
||||
.get_writable_roots_with_cwd(cwd)
|
||||
.into_iter()
|
||||
.map(|writable_root| writable_root.root)
|
||||
.collect();
|
||||
install_filesystem_landlock_rules_on_current_thread(writable_roots)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ async fn run_cmd(cmd: &[&str], writable_roots: &[PathBuf], timeout_ms: u64) {
|
||||
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: writable_roots.to_vec(),
|
||||
network_access: false,
|
||||
include_default_writable_roots: true,
|
||||
};
|
||||
let sandbox_program = env!("CARGO_BIN_EXE_codex-linux-sandbox");
|
||||
let codex_linux_sandbox_exe = Some(PathBuf::from(sandbox_program));
|
||||
@@ -59,6 +60,7 @@ async fn run_cmd(cmd: &[&str], writable_roots: &[PathBuf], timeout_ms: u64) {
|
||||
ctrl_c,
|
||||
&sandbox_policy,
|
||||
&codex_linux_sandbox_exe,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -149,6 +151,7 @@ async fn assert_network_blocked(cmd: &[&str]) {
|
||||
ctrl_c,
|
||||
&sandbox_policy,
|
||||
&codex_linux_sandbox_exe,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -686,6 +686,7 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
background: white;
|
||||
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
|
||||
}
|
||||
.inner-container {
|
||||
@@ -703,6 +704,7 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
display: flex;
|
||||
margin-top: 15vh;
|
||||
}
|
||||
.svg-wrapper {
|
||||
position: relative;
|
||||
@@ -710,9 +712,9 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
.title {
|
||||
text-align: center;
|
||||
color: var(--text-primary, #0D0D0D);
|
||||
font-size: 28px;
|
||||
font-size: 32px;
|
||||
font-weight: 400;
|
||||
line-height: 36.40px;
|
||||
line-height: 40px;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
.setup-box {
|
||||
@@ -785,16 +787,26 @@ LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
|
||||
word-wrap: break-word;
|
||||
text-decoration: none;
|
||||
}
|
||||
.logo {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 4rem;
|
||||
height: 4rem;
|
||||
border-radius: 16px;
|
||||
border: .5px solid rgba(0, 0, 0, 0.1);
|
||||
box-shadow: rgba(0, 0, 0, 0.1) 0px 4px 16px 0px;
|
||||
box-sizing: border-box;
|
||||
background-color: rgb(255, 255, 255);
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="inner-container">
|
||||
<div class="content">
|
||||
<div data-svg-wrapper class="svg-wrapper">
|
||||
<svg width="56" height="56" viewBox="0 0 56 56" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4.6665 28.0003C4.6665 15.1137 15.1132 4.66699 27.9998 4.66699C40.8865 4.66699 51.3332 15.1137 51.3332 28.0003C51.3332 40.887 40.8865 51.3337 27.9998 51.3337C15.1132 51.3337 4.6665 40.887 4.6665 28.0003ZM37.5093 18.5088C36.4554 17.7672 34.9999 18.0203 34.2583 19.0742L24.8508 32.4427L20.9764 28.1808C20.1095 27.2272 18.6338 27.1569 17.6803 28.0238C16.7267 28.8906 16.6565 30.3664 17.5233 31.3199L23.3566 37.7366C23.833 38.2606 24.5216 38.5399 25.2284 38.4958C25.9353 38.4517 26.5838 38.089 26.9914 37.5098L38.0747 21.7598C38.8163 20.7059 38.5632 19.2504 37.5093 18.5088Z" fill="var(--green-400, #04B84C)"/>
|
||||
</svg>
|
||||
<div class="logo">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="none" viewBox="0 0 32 32"><path stroke="#000" stroke-linecap="round" stroke-width="2.484" d="M22.356 19.797H17.17M9.662 12.29l1.979 3.576a.511.511 0 0 1-.005.504l-1.974 3.409M30.758 16c0 8.15-6.607 14.758-14.758 14.758-8.15 0-14.758-6.607-14.758-14.758C1.242 7.85 7.85 1.242 16 1.242c8.15 0 14.758 6.608 14.758 14.758Z"></path></svg>
|
||||
</div>
|
||||
<div class="title">Signed in to Codex CLI</div>
|
||||
</div>
|
||||
|
||||
@@ -258,6 +258,7 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
|
||||
121
codex-rs/mcp-server/src/conversation_loop.rs
Normal file
121
codex-rs/mcp-server/src/conversation_loop.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::OutgoingNotificationMeta;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tracing::error;
|
||||
|
||||
pub async fn run_conversation_loop(
|
||||
codex: Arc<Codex>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
) {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
|
||||
// Stream events until the task needs to pause for user interaction or
|
||||
// completes.
|
||||
loop {
|
||||
match codex.next_event().await {
|
||||
Ok(event) => {
|
||||
outgoing
|
||||
.send_event_as_notification(
|
||||
&event,
|
||||
Some(OutgoingNotificationMeta::new(Some(request_id.clone()))),
|
||||
)
|
||||
.await;
|
||||
|
||||
match event.msg {
|
||||
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
command,
|
||||
cwd,
|
||||
call_id,
|
||||
reason: _,
|
||||
}) => {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::Error(_) => {
|
||||
error!("Codex runtime error");
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(_) => {}
|
||||
EventMsg::SessionConfigured(_) => {
|
||||
tracing::error!("unexpected SessionConfigured event");
|
||||
}
|
||||
EventMsg::AgentMessageDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentReasoningDelta(_) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||
// TODO: think how we want to support this in the MCP
|
||||
}
|
||||
EventMsg::TaskStarted
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
// For now, we do not do anything extra for these
|
||||
// events. Note that
|
||||
// send(codex_event_to_notification(&event)) above has
|
||||
// already dispatched these events as notifications,
|
||||
// though we may want to do give different treatment to
|
||||
// individual events in the future.
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Codex runtime error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod conversation_loop;
|
||||
mod exec_approval;
|
||||
mod json_to_toml;
|
||||
pub mod mcp_protocol;
|
||||
|
||||
@@ -172,15 +172,22 @@ pub enum ToolCallResponseResult {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ConversationCreateResult {
|
||||
pub conversation_id: ConversationId,
|
||||
pub model: String,
|
||||
#[serde(untagged)]
|
||||
pub enum ConversationCreateResult {
|
||||
Ok {
|
||||
conversation_id: ConversationId,
|
||||
model: String,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ConversationStreamResult {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
// TODO: remove this status because we have is_error field in the response.
|
||||
#[serde(tag = "status", rename_all = "camelCase")]
|
||||
pub enum ConversationSendMessageResult {
|
||||
Ok,
|
||||
@@ -491,7 +498,7 @@ mod tests {
|
||||
request_id: RequestId::Integer(1),
|
||||
is_error: None,
|
||||
result: Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult {
|
||||
ConversationCreateResult::Ok {
|
||||
conversation_id: ConversationId(uuid!("d0f6ecbe-84a2-41c1-b23d-b20473b25eab")),
|
||||
model: "o3".into(),
|
||||
},
|
||||
@@ -515,6 +522,35 @@ mod tests {
|
||||
assert_eq!(req_id, RequestId::Integer(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_error_conversation_create_full_schema() {
|
||||
let env = ToolCallResponse {
|
||||
request_id: RequestId::Integer(2),
|
||||
is_error: Some(true),
|
||||
result: Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: "Failed to initialize session".into(),
|
||||
},
|
||||
)),
|
||||
};
|
||||
let req_id = env.request_id.clone();
|
||||
let observed = to_val(&CallToolResult::from(env));
|
||||
let expected = json!({
|
||||
"content": [
|
||||
{ "type": "text", "text": "{\"message\":\"Failed to initialize session\"}" }
|
||||
],
|
||||
"isError": true,
|
||||
"structuredContent": {
|
||||
"message": "Failed to initialize session"
|
||||
}
|
||||
});
|
||||
assert_eq!(
|
||||
observed, expected,
|
||||
"error response (ConversationCreate) must match"
|
||||
);
|
||||
assert_eq!(req_id, RequestId::Integer(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_success_conversation_stream_empty_result_object() {
|
||||
let env = ToolCallResponse {
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::mcp_protocol::ToolCallRequestParams;
|
||||
use crate::mcp_protocol::ToolCallResponse;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::tool_handlers::create_conversation::handle_create_conversation;
|
||||
use crate::tool_handlers::send_message::handle_send_message;
|
||||
|
||||
use codex_core::Codex;
|
||||
@@ -67,6 +68,10 @@ impl MessageProcessor {
|
||||
self.session_map.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn outgoing(&self) -> Arc<OutgoingMessageSender> {
|
||||
self.outgoing.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn running_session_ids(&self) -> Arc<Mutex<HashSet<Uuid>>> {
|
||||
self.running_session_ids.clone()
|
||||
}
|
||||
@@ -349,6 +354,9 @@ impl MessageProcessor {
|
||||
}
|
||||
async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) {
|
||||
match params {
|
||||
ToolCallRequestParams::ConversationCreate(args) => {
|
||||
handle_create_conversation(self, request_id, args).await;
|
||||
}
|
||||
ToolCallRequestParams::ConversationSendMessage(args) => {
|
||||
handle_send_message(self, request_id, args).await;
|
||||
}
|
||||
|
||||
160
codex-rs/mcp-server/src/tool_handlers/create_conversation.rs
Normal file
160
codex-rs/mcp-server/src/tool_handlers/create_conversation.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use mcp_types::RequestId;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::conversation_loop::run_conversation_loop;
|
||||
use crate::json_to_toml::json_to_toml;
|
||||
use crate::mcp_protocol::ConversationCreateArgs;
|
||||
use crate::mcp_protocol::ConversationCreateResult;
|
||||
use crate::mcp_protocol::ConversationId;
|
||||
use crate::mcp_protocol::ToolCallResponseResult;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
|
||||
pub(crate) async fn handle_create_conversation(
|
||||
message_processor: &MessageProcessor,
|
||||
id: RequestId,
|
||||
args: ConversationCreateArgs,
|
||||
) {
|
||||
// Build ConfigOverrides from args
|
||||
let ConversationCreateArgs {
|
||||
prompt: _, // not used here; creation only establishes the session
|
||||
model,
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox,
|
||||
config,
|
||||
profile,
|
||||
base_instructions,
|
||||
} = args;
|
||||
|
||||
// Convert config overrides JSON into CLI-style TOML overrides
|
||||
let cli_overrides: Vec<(String, toml::Value)> = match config {
|
||||
Some(v) => match v.as_object() {
|
||||
Some(map) => map
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.clone(), json_to_toml(v.clone())))
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
},
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
let overrides = ConfigOverrides {
|
||||
model: Some(model.clone()),
|
||||
cwd: Some(PathBuf::from(cwd)),
|
||||
approval_policy,
|
||||
sandbox_mode: sandbox,
|
||||
model_provider: None,
|
||||
config_profile: profile,
|
||||
codex_linux_sandbox_exe: None,
|
||||
base_instructions,
|
||||
include_plan_tool: None,
|
||||
};
|
||||
|
||||
let cfg: CodexConfig = match CodexConfig::load_with_cli_overrides(cli_overrides, overrides) {
|
||||
Ok(cfg) => cfg,
|
||||
Err(e) => {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: format!("Failed to load config: {e}"),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize Codex session
|
||||
let codex_conversation = match init_codex(cfg).await {
|
||||
Ok(conv) => conv,
|
||||
Err(e) => {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: format!("Failed to initialize session: {e}"),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Expect SessionConfigured; if not, return error.
|
||||
let EventMsg::SessionConfigured(SessionConfiguredEvent { model, .. }) =
|
||||
&codex_conversation.session_configured.msg
|
||||
else {
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Error {
|
||||
message: "Expected SessionConfigured event".to_string(),
|
||||
},
|
||||
)),
|
||||
Some(true),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
let effective_model = model.clone();
|
||||
|
||||
let session_id = codex_conversation.session_id;
|
||||
let codex_arc = Arc::new(codex_conversation.codex);
|
||||
|
||||
// Store session for future calls
|
||||
insert_session(
|
||||
session_id,
|
||||
codex_arc.clone(),
|
||||
message_processor.session_map(),
|
||||
)
|
||||
.await;
|
||||
// Run the conversation loop in the background so this request can return immediately.
|
||||
let outgoing = message_processor.outgoing();
|
||||
let spawn_id = id.clone();
|
||||
tokio::spawn(async move {
|
||||
run_conversation_loop(codex_arc.clone(), outgoing, spawn_id).await;
|
||||
});
|
||||
|
||||
// Reply with the new conversation id and effective model
|
||||
message_processor
|
||||
.send_response_with_optional_error(
|
||||
id,
|
||||
Some(ToolCallResponseResult::ConversationCreate(
|
||||
ConversationCreateResult::Ok {
|
||||
conversation_id: ConversationId(session_id),
|
||||
model: effective_model,
|
||||
},
|
||||
)),
|
||||
Some(false),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
session_id: Uuid,
|
||||
codex: Arc<Codex>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
) {
|
||||
let mut guard = session_map.lock().await;
|
||||
guard.insert(session_id, codex);
|
||||
}
|
||||
@@ -1 +1,2 @@
|
||||
pub(crate) mod create_conversation;
|
||||
pub(crate) mod send_message;
|
||||
|
||||
@@ -14,6 +14,7 @@ use assert_cmd::prelude::*;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_mcp_server::CodexToolCallParam;
|
||||
use codex_mcp_server::CodexToolCallReplyParam;
|
||||
use codex_mcp_server::mcp_protocol::ConversationCreateArgs;
|
||||
use codex_mcp_server::mcp_protocol::ConversationId;
|
||||
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
|
||||
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
|
||||
@@ -200,6 +201,41 @@ impl McpProcess {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_create_tool_call(
|
||||
&mut self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
cwd: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationCreate(ConversationCreateArgs {
|
||||
prompt: prompt.to_string(),
|
||||
model: model.to_string(),
|
||||
cwd: cwd.to_string(),
|
||||
approval_policy: None,
|
||||
sandbox: None,
|
||||
config: None,
|
||||
profile: None,
|
||||
base_instructions: None,
|
||||
});
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn send_conversation_create_with_args(
|
||||
&mut self,
|
||||
args: ConversationCreateArgs,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = ToolCallRequestParams::ConversationCreate(args);
|
||||
self.send_request(
|
||||
mcp_types::CallToolRequest::METHOD,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
|
||||
128
codex-rs/mcp-server/tests/create_conversation.rs
Normal file
128
codex-rs/mcp-server/tests/create_conversation.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use mcp_test_support::McpProcess;
|
||||
use mcp_test_support::create_final_assistant_message_sse_response;
|
||||
use mcp_test_support::create_mock_chat_completions_server;
|
||||
use mcp_types::JSONRPCResponse;
|
||||
use mcp_types::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_conversation_create_and_send_message_ok() {
|
||||
// Mock server – we won't strictly rely on it, but provide one to satisfy any model wiring.
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
|
||||
];
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
|
||||
// Temporary Codex home with config pointing at the mock server.
|
||||
let codex_home = TempDir::new().expect("create temp dir");
|
||||
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
|
||||
|
||||
// Start MCP server process and initialize.
|
||||
let mut mcp = McpProcess::new(codex_home.path())
|
||||
.await
|
||||
.expect("spawn mcp process");
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||
.await
|
||||
.expect("init timeout")
|
||||
.expect("init failed");
|
||||
|
||||
// Create a conversation via the new tool.
|
||||
let req_id = mcp
|
||||
.send_conversation_create_tool_call("", "o3", "/repo")
|
||||
.await
|
||||
.expect("send conversationCreate");
|
||||
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(req_id)),
|
||||
)
|
||||
.await
|
||||
.expect("create response timeout")
|
||||
.expect("create response error");
|
||||
|
||||
// Structured content must include status=ok, a UUID conversation_id and the model we passed.
|
||||
let sc = &resp.result["structuredContent"];
|
||||
let conv_id = sc["conversation_id"].as_str().expect("uuid string");
|
||||
assert!(!conv_id.is_empty());
|
||||
assert_eq!(sc["model"], json!("o3"));
|
||||
|
||||
// Now send a message to the created conversation and expect an OK result.
|
||||
let send_id = mcp
|
||||
.send_user_message_tool_call("Hello", conv_id)
|
||||
.await
|
||||
.expect("send message");
|
||||
|
||||
let send_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(send_id)),
|
||||
)
|
||||
.await
|
||||
.expect("send response timeout")
|
||||
.expect("send response error");
|
||||
assert_eq!(
|
||||
send_resp.result["structuredContent"],
|
||||
json!({ "status": "ok" })
|
||||
);
|
||||
|
||||
// avoid race condition by waiting for the mock server to receive the chat.completions request
|
||||
let deadline = std::time::Instant::now() + DEFAULT_READ_TIMEOUT;
|
||||
loop {
|
||||
let requests = server.received_requests().await.unwrap_or_default();
|
||||
if !requests.is_empty() {
|
||||
break;
|
||||
}
|
||||
if std::time::Instant::now() >= deadline {
|
||||
panic!("mock server did not receive the chat.completions request in time");
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
|
||||
// Verify the outbound request body matches expectations for Chat Completions.
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let body = request
|
||||
.body_json::<serde_json::Value>()
|
||||
.expect("parse request body as JSON");
|
||||
assert_eq!(body["model"], json!("o3"));
|
||||
assert!(body["stream"].as_bool().unwrap_or(false));
|
||||
let messages = body["messages"]
|
||||
.as_array()
|
||||
.expect("messages should be array");
|
||||
let last = messages.last().expect("at least one message");
|
||||
assert_eq!(last["role"], json!("user"));
|
||||
assert_eq!(last["content"], json!("Hello"));
|
||||
|
||||
drop(server);
|
||||
}
|
||||
|
||||
// Helper to create a config.toml pointing at the mock model server.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "danger-full-access"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -17,6 +17,7 @@ workspace = true
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
base64 = "0.22.1"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-ansi-escape = { path = "../ansi-escape" }
|
||||
codex-arg0 = { path = "../arg0" }
|
||||
@@ -41,6 +42,8 @@ ratatui = { version = "0.29.0", features = [
|
||||
] }
|
||||
ratatui-image = "8.0.0"
|
||||
regex-lite = "0.1"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1", features = ["preserve_order"] }
|
||||
shlex = "1.3.0"
|
||||
strum = "0.27.2"
|
||||
|
||||
@@ -43,7 +43,6 @@ pub(crate) struct ChatComposer<'a> {
|
||||
ctrl_c_quit_hint: bool,
|
||||
use_shift_enter_hint: bool,
|
||||
dismissed_file_popup_token: Option<String>,
|
||||
dismissed_slash_token: Option<String>,
|
||||
current_file_query: Option<String>,
|
||||
pending_pastes: Vec<(String, String)>,
|
||||
}
|
||||
@@ -56,95 +55,6 @@ enum ActivePopup {
|
||||
}
|
||||
|
||||
impl ChatComposer<'_> {
|
||||
#[inline]
|
||||
fn first_line(&self) -> &str {
|
||||
self.textarea
|
||||
.lines()
|
||||
.first()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn slash_token_from_first_line(first_line: &str) -> Option<&str> {
|
||||
if !first_line.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
let stripped = first_line.strip_prefix('/').unwrap_or("");
|
||||
let token = stripped.trim_start();
|
||||
Some(token.split_whitespace().next().unwrap_or(""))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn emit_unrecognized_slash_command(&mut self, cmd_token: &str) {
|
||||
let attempted = if cmd_token.is_empty() {
|
||||
"/".to_string()
|
||||
} else {
|
||||
format!("/{cmd_token}")
|
||||
};
|
||||
let msg = format!("{attempted} not a recognized command");
|
||||
self.app_event_tx
|
||||
.send(AppEvent::InsertHistory(vec![Line::from(msg)]));
|
||||
self.dismissed_slash_token = Some(cmd_token.to_string());
|
||||
self.active_popup = ActivePopup::None;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn sync_popups(&mut self) {
|
||||
self.sync_command_popup();
|
||||
if matches!(self.active_popup, ActivePopup::Command(_)) {
|
||||
self.dismissed_file_popup_token = None;
|
||||
} else {
|
||||
self.sync_file_search_popup();
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn compute_textarea_and_popup_rect(&self, area: Rect, desired_popup: u16) -> (Rect, Rect) {
|
||||
let text_lines = self.textarea.lines().len().max(1) as u16;
|
||||
let popup_height = desired_popup.min(area.height.saturating_sub(text_lines));
|
||||
let textarea_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y,
|
||||
width: area.width,
|
||||
height: area.height.saturating_sub(popup_height),
|
||||
};
|
||||
let popup_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + textarea_rect.height,
|
||||
width: area.width,
|
||||
height: popup_height,
|
||||
};
|
||||
(textarea_rect, popup_rect)
|
||||
}
|
||||
|
||||
/// Convert a cursor column (in chars) to a byte offset within `line`.
|
||||
#[inline]
|
||||
fn cursor_byte_offset_for_col(line: &str, col_chars: usize) -> usize {
|
||||
line.chars().take(col_chars).map(|c| c.len_utf8()).sum()
|
||||
}
|
||||
|
||||
/// Determine token boundaries around a cursor position expressed as a byte offset.
|
||||
/// A token is delimited by any Unicode whitespace on either side.
|
||||
#[inline]
|
||||
fn token_bounds(line: &str, cursor_byte_offset: usize) -> Option<(usize, usize)> {
|
||||
let (before_cursor, after_cursor) = line.split_at(cursor_byte_offset.min(line.len()));
|
||||
|
||||
let start_idx = before_cursor
|
||||
.char_indices()
|
||||
.rfind(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, c)| idx + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
|
||||
let end_rel_idx = after_cursor
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(after_cursor.len());
|
||||
let end_idx = cursor_byte_offset + end_rel_idx;
|
||||
|
||||
(start_idx < end_idx).then_some((start_idx, end_idx))
|
||||
}
|
||||
pub fn new(
|
||||
has_input_focus: bool,
|
||||
app_event_tx: AppEventSender,
|
||||
@@ -164,7 +74,6 @@ impl ChatComposer<'_> {
|
||||
ctrl_c_quit_hint: false,
|
||||
use_shift_enter_hint,
|
||||
dismissed_file_popup_token: None,
|
||||
dismissed_slash_token: None,
|
||||
current_file_query: None,
|
||||
pending_pastes: Vec::new(),
|
||||
};
|
||||
@@ -246,7 +155,8 @@ impl ChatComposer<'_> {
|
||||
} else {
|
||||
self.textarea.insert_str(&pasted);
|
||||
}
|
||||
self.sync_popups();
|
||||
self.sync_command_popup();
|
||||
self.sync_file_search_popup();
|
||||
true
|
||||
}
|
||||
|
||||
@@ -280,14 +190,19 @@ impl ChatComposer<'_> {
|
||||
ActivePopup::None => self.handle_key_event_without_popup(key_event),
|
||||
};
|
||||
|
||||
self.sync_popups();
|
||||
// Update (or hide/show) popup after processing the key.
|
||||
self.sync_command_popup();
|
||||
if matches!(self.active_popup, ActivePopup::Command(_)) {
|
||||
self.dismissed_file_popup_token = None;
|
||||
} else {
|
||||
self.sync_file_search_popup();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Handle key event when the slash-command popup is visible.
|
||||
fn handle_key_event_with_slash_popup(&mut self, key_event: KeyEvent) -> (InputResult, bool) {
|
||||
let first_line_owned = self.first_line().to_string();
|
||||
let ActivePopup::Command(popup) = &mut self.active_popup else {
|
||||
unreachable!();
|
||||
};
|
||||
@@ -301,16 +216,16 @@ impl ChatComposer<'_> {
|
||||
popup.move_down();
|
||||
(InputResult::None, true)
|
||||
}
|
||||
Input { key: Key::Esc, .. } => {
|
||||
if let Some(cmd_token) = Self::slash_token_from_first_line(&first_line_owned) {
|
||||
self.dismissed_slash_token = Some(cmd_token.to_string());
|
||||
}
|
||||
self.active_popup = ActivePopup::None;
|
||||
(InputResult::None, true)
|
||||
}
|
||||
Input { key: Key::Tab, .. } => {
|
||||
if let Some(cmd) = popup.selected_command() {
|
||||
let starts_with_cmd = first_line_owned
|
||||
let first_line = self
|
||||
.textarea
|
||||
.lines()
|
||||
.first()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let starts_with_cmd = first_line
|
||||
.trim_start()
|
||||
.starts_with(&format!("/{}", cmd.command()));
|
||||
|
||||
@@ -329,17 +244,19 @@ impl ChatComposer<'_> {
|
||||
ctrl: false,
|
||||
} => {
|
||||
if let Some(cmd) = popup.selected_command() {
|
||||
// Send command to the app layer.
|
||||
self.app_event_tx.send(AppEvent::DispatchCommand(*cmd));
|
||||
|
||||
// Clear textarea so no residual text remains.
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
|
||||
// Hide popup since the command has been dispatched.
|
||||
self.active_popup = ActivePopup::None;
|
||||
return (InputResult::None, true);
|
||||
}
|
||||
if let Some(cmd_token) = Self::slash_token_from_first_line(&first_line_owned) {
|
||||
self.emit_unrecognized_slash_command(cmd_token);
|
||||
return (InputResult::None, true);
|
||||
}
|
||||
(InputResult::None, false)
|
||||
// Fallback to default newline handling if no command selected.
|
||||
self.handle_key_event_without_popup(key_event)
|
||||
}
|
||||
input => self.handle_input_basic(input),
|
||||
}
|
||||
@@ -400,9 +317,37 @@ impl ChatComposer<'_> {
|
||||
/// one additional character, that token (without `@`) is returned.
|
||||
fn current_at_token(textarea: &tui_textarea::TextArea) -> Option<String> {
|
||||
let (row, col) = textarea.cursor();
|
||||
|
||||
// Guard against out-of-bounds rows.
|
||||
let line = textarea.lines().get(row)?.as_str();
|
||||
let cursor_byte_offset = Self::cursor_byte_offset_for_col(line, col);
|
||||
let (start_idx, end_idx) = Self::token_bounds(line, cursor_byte_offset)?;
|
||||
|
||||
// Calculate byte offset for cursor position
|
||||
let cursor_byte_offset = line.chars().take(col).map(|c| c.len_utf8()).sum::<usize>();
|
||||
|
||||
// Split the line at the cursor position so we can search for word
|
||||
// boundaries on both sides.
|
||||
let before_cursor = &line[..cursor_byte_offset];
|
||||
let after_cursor = &line[cursor_byte_offset..];
|
||||
|
||||
// Find start index (first character **after** the previous multi-byte whitespace).
|
||||
let start_idx = before_cursor
|
||||
.char_indices()
|
||||
.rfind(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, c)| idx + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
|
||||
// Find end index (first multi-byte whitespace **after** the cursor position).
|
||||
let end_rel_idx = after_cursor
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(after_cursor.len());
|
||||
let end_idx = cursor_byte_offset + end_rel_idx;
|
||||
|
||||
if start_idx >= end_idx {
|
||||
return None;
|
||||
}
|
||||
|
||||
let token = &line[start_idx..end_idx];
|
||||
|
||||
if token.starts_with('@') && token.len() > 1 {
|
||||
@@ -419,24 +364,50 @@ impl ChatComposer<'_> {
|
||||
/// `@tokens` exist in the line.
|
||||
fn insert_selected_path(&mut self, path: &str) {
|
||||
let (row, col) = self.textarea.cursor();
|
||||
|
||||
// Materialize the textarea lines so we can mutate them easily.
|
||||
let mut lines: Vec<String> = self.textarea.lines().to_vec();
|
||||
|
||||
if let Some(line) = lines.get_mut(row) {
|
||||
let cursor_byte_offset = Self::cursor_byte_offset_for_col(line, col);
|
||||
if let Some((start_idx, end_idx)) = Self::token_bounds(line, cursor_byte_offset) {
|
||||
let mut new_line =
|
||||
String::with_capacity(line.len() - (end_idx - start_idx) + path.len() + 1);
|
||||
new_line.push_str(&line[..start_idx]);
|
||||
new_line.push_str(path);
|
||||
new_line.push(' ');
|
||||
new_line.push_str(&line[end_idx..]);
|
||||
*line = new_line;
|
||||
// Calculate byte offset for cursor position
|
||||
let cursor_byte_offset = line.chars().take(col).map(|c| c.len_utf8()).sum::<usize>();
|
||||
|
||||
let new_text = lines.join("\n");
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
let _ = self.textarea.insert_str(new_text);
|
||||
}
|
||||
let before_cursor = &line[..cursor_byte_offset];
|
||||
let after_cursor = &line[cursor_byte_offset..];
|
||||
|
||||
// Determine token boundaries.
|
||||
let start_idx = before_cursor
|
||||
.char_indices()
|
||||
.rfind(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, c)| idx + c.len_utf8())
|
||||
.unwrap_or(0);
|
||||
|
||||
let end_rel_idx = after_cursor
|
||||
.char_indices()
|
||||
.find(|(_, c)| c.is_whitespace())
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(after_cursor.len());
|
||||
let end_idx = cursor_byte_offset + end_rel_idx;
|
||||
|
||||
// Replace the slice `[start_idx, end_idx)` with the chosen path and a trailing space.
|
||||
let mut new_line =
|
||||
String::with_capacity(line.len() - (end_idx - start_idx) + path.len() + 1);
|
||||
new_line.push_str(&line[..start_idx]);
|
||||
new_line.push_str(path);
|
||||
new_line.push(' ');
|
||||
new_line.push_str(&line[end_idx..]);
|
||||
|
||||
*line = new_line;
|
||||
|
||||
// Re-populate the textarea.
|
||||
let new_text = lines.join("\n");
|
||||
self.textarea.select_all();
|
||||
self.textarea.cut();
|
||||
let _ = self.textarea.insert_str(new_text);
|
||||
|
||||
// Note: tui-textarea currently exposes only relative cursor
|
||||
// movements. Leaving the cursor position unchanged is acceptable
|
||||
// as subsequent typing will move the cursor naturally.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -608,28 +579,29 @@ impl ChatComposer<'_> {
|
||||
/// textarea. This must be called after every modification that can change
|
||||
/// the text so the popup is shown/updated/hidden as appropriate.
|
||||
fn sync_command_popup(&mut self) {
|
||||
let first_line = self.first_line().to_string();
|
||||
// Inspect only the first line to decide whether to show the popup. In
|
||||
// the common case (no leading slash) we avoid copying the entire
|
||||
// textarea contents.
|
||||
let first_line = self
|
||||
.textarea
|
||||
.lines()
|
||||
.first()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let input_starts_with_slash = first_line.starts_with('/');
|
||||
if !input_starts_with_slash {
|
||||
self.dismissed_slash_token = None;
|
||||
}
|
||||
let current_cmd_token: Option<&str> = Self::slash_token_from_first_line(&first_line);
|
||||
match &mut self.active_popup {
|
||||
ActivePopup::Command(popup) => {
|
||||
if input_starts_with_slash {
|
||||
popup.on_composer_text_change(first_line.clone());
|
||||
popup.on_composer_text_change(first_line.to_string());
|
||||
} else {
|
||||
self.active_popup = ActivePopup::None;
|
||||
self.dismissed_slash_token = None;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if input_starts_with_slash {
|
||||
if self.dismissed_slash_token.as_deref() == current_cmd_token {
|
||||
return;
|
||||
}
|
||||
let mut command_popup = CommandPopup::new();
|
||||
command_popup.on_composer_text_change(first_line);
|
||||
command_popup.on_composer_text_change(first_line.to_string());
|
||||
self.active_popup = ActivePopup::Command(command_popup);
|
||||
}
|
||||
}
|
||||
@@ -692,16 +664,44 @@ impl WidgetRef for &ChatComposer<'_> {
|
||||
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
|
||||
match &self.active_popup {
|
||||
ActivePopup::Command(popup) => {
|
||||
let desired_popup = popup.calculate_required_height();
|
||||
let (textarea_rect, popup_rect) =
|
||||
self.compute_textarea_and_popup_rect(area, desired_popup);
|
||||
let popup_height = popup.calculate_required_height();
|
||||
|
||||
// Split the provided rect so that the popup is rendered at the
|
||||
// **bottom** and the textarea occupies the remaining space above.
|
||||
let popup_height = popup_height.min(area.height);
|
||||
let textarea_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y,
|
||||
width: area.width,
|
||||
height: area.height.saturating_sub(popup_height),
|
||||
};
|
||||
let popup_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + textarea_rect.height,
|
||||
width: area.width,
|
||||
height: popup_height,
|
||||
};
|
||||
|
||||
popup.render(popup_rect, buf);
|
||||
self.textarea.render(textarea_rect, buf);
|
||||
}
|
||||
ActivePopup::File(popup) => {
|
||||
let desired_popup = popup.calculate_required_height();
|
||||
let (textarea_rect, popup_rect) =
|
||||
self.compute_textarea_and_popup_rect(area, desired_popup);
|
||||
let popup_height = popup.calculate_required_height();
|
||||
|
||||
let popup_height = popup_height.min(area.height);
|
||||
let textarea_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y,
|
||||
width: area.width,
|
||||
height: area.height.saturating_sub(popup_height),
|
||||
};
|
||||
let popup_rect = Rect {
|
||||
x: area.x,
|
||||
y: area.y + textarea_rect.height,
|
||||
width: area.width,
|
||||
height: popup_height,
|
||||
};
|
||||
|
||||
popup.render(popup_rect, buf);
|
||||
self.textarea.render(textarea_rect, buf);
|
||||
}
|
||||
@@ -745,7 +745,6 @@ impl WidgetRef for &ChatComposer<'_> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::bottom_pane::AppEventSender;
|
||||
use crate::bottom_pane::ChatComposer;
|
||||
use crate::bottom_pane::InputResult;
|
||||
@@ -1021,97 +1020,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn esc_dismiss_slash_popup_reopen_on_token_change() {
|
||||
use crate::bottom_pane::chat_composer::ActivePopup;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyModifiers;
|
||||
|
||||
let (tx, _rx) = std::sync::mpsc::channel();
|
||||
let sender = AppEventSender::new(tx);
|
||||
let mut composer = ChatComposer::new(true, sender, false);
|
||||
|
||||
composer.handle_paste("/".to_string());
|
||||
assert!(matches!(composer.active_popup, ActivePopup::Command(_)));
|
||||
|
||||
let _ = composer.handle_key_event(KeyEvent::new(KeyCode::Esc, KeyModifiers::NONE));
|
||||
assert!(matches!(composer.active_popup, ActivePopup::None));
|
||||
|
||||
composer.handle_paste("c".to_string());
|
||||
assert!(matches!(composer.active_popup, ActivePopup::Command(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enter_with_unrecognized_slash_command_closes_popup_and_emits_error() {
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::bottom_pane::chat_composer::ActivePopup;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyModifiers;
|
||||
use std::sync::mpsc::TryRecvError;
|
||||
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let sender = AppEventSender::new(tx);
|
||||
let mut composer = ChatComposer::new(true, sender, false);
|
||||
|
||||
composer.handle_paste("/ notacommand args".to_string());
|
||||
assert!(matches!(composer.active_popup, ActivePopup::Command(_)));
|
||||
|
||||
let _ = composer.handle_key_event(KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE));
|
||||
|
||||
assert!(matches!(composer.active_popup, ActivePopup::None));
|
||||
|
||||
let mut saw_error = false;
|
||||
loop {
|
||||
match rx.try_recv() {
|
||||
Ok(AppEvent::InsertHistory(lines)) => {
|
||||
let text = lines
|
||||
.into_iter()
|
||||
.map(|l| l.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
if text.contains("/notacommand not a recognized command") {
|
||||
saw_error = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(TryRecvError::Empty) => break,
|
||||
Err(TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_error,
|
||||
"expected InsertHistory error for unrecognized command"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn esc_dismiss_then_delete_and_retype_slash_reopens_popup() {
|
||||
use crate::bottom_pane::chat_composer::ActivePopup;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::KeyModifiers;
|
||||
|
||||
let (tx, _rx) = std::sync::mpsc::channel();
|
||||
let sender = AppEventSender::new(tx);
|
||||
let mut composer = ChatComposer::new(true, sender, false);
|
||||
|
||||
composer.handle_paste("/".to_string());
|
||||
assert!(matches!(composer.active_popup, ActivePopup::Command(_)));
|
||||
|
||||
let _ = composer.handle_key_event(KeyEvent::new(KeyCode::Esc, KeyModifiers::NONE));
|
||||
assert!(matches!(composer.active_popup, ActivePopup::None));
|
||||
|
||||
let _ = composer.handle_key_event(KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE));
|
||||
assert!(matches!(composer.active_popup, ActivePopup::None));
|
||||
|
||||
composer.handle_paste("/".to_string());
|
||||
assert!(matches!(composer.active_popup, ActivePopup::Command(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_pastes_submission() {
|
||||
use crossterm::event::KeyCode;
|
||||
@@ -1122,15 +1030,18 @@ mod tests {
|
||||
let sender = AppEventSender::new(tx);
|
||||
let mut composer = ChatComposer::new(true, sender, false);
|
||||
|
||||
// Define test cases: (paste content, is_large)
|
||||
let test_cases = [
|
||||
("x".repeat(LARGE_PASTE_CHAR_THRESHOLD + 3), true),
|
||||
(" and ".to_string(), false),
|
||||
("y".repeat(LARGE_PASTE_CHAR_THRESHOLD + 7), true),
|
||||
];
|
||||
|
||||
// Expected states after each paste
|
||||
let mut expected_text = String::new();
|
||||
let mut expected_pending_count = 0;
|
||||
|
||||
// Apply all pastes and build expected state
|
||||
let states: Vec<_> = test_cases
|
||||
.iter()
|
||||
.map(|(content, is_large)| {
|
||||
@@ -1146,6 +1057,7 @@ mod tests {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Verify all intermediate states were correct
|
||||
assert_eq!(
|
||||
states,
|
||||
vec![
|
||||
@@ -1171,6 +1083,7 @@ mod tests {
|
||||
]
|
||||
);
|
||||
|
||||
// Submit and verify final expansion
|
||||
let (result, _) =
|
||||
composer.handle_key_event(KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE));
|
||||
if let InputResult::Submitted(text) = result {
|
||||
@@ -1190,12 +1103,14 @@ mod tests {
|
||||
let sender = AppEventSender::new(tx);
|
||||
let mut composer = ChatComposer::new(true, sender, false);
|
||||
|
||||
// Define test cases: (content, is_large)
|
||||
let test_cases = [
|
||||
("a".repeat(LARGE_PASTE_CHAR_THRESHOLD + 5), true),
|
||||
(" and ".to_string(), false),
|
||||
("b".repeat(LARGE_PASTE_CHAR_THRESHOLD + 6), true),
|
||||
];
|
||||
|
||||
// Apply all pastes
|
||||
let mut current_pos = 0;
|
||||
let states: Vec<_> = test_cases
|
||||
.iter()
|
||||
@@ -1215,8 +1130,10 @@ mod tests {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Delete placeholders one by one and collect states
|
||||
let mut deletion_states = vec![];
|
||||
|
||||
// First deletion
|
||||
composer
|
||||
.textarea
|
||||
.move_cursor(tui_textarea::CursorMove::Jump(0, states[0].2 as u16));
|
||||
@@ -1226,6 +1143,7 @@ mod tests {
|
||||
composer.pending_pastes.len(),
|
||||
));
|
||||
|
||||
// Second deletion
|
||||
composer
|
||||
.textarea
|
||||
.move_cursor(tui_textarea::CursorMove::Jump(
|
||||
@@ -1238,6 +1156,7 @@ mod tests {
|
||||
composer.pending_pastes.len(),
|
||||
));
|
||||
|
||||
// Verify all states
|
||||
assert_eq!(
|
||||
deletion_states,
|
||||
vec![
|
||||
@@ -1257,6 +1176,7 @@ mod tests {
|
||||
let sender = AppEventSender::new(tx);
|
||||
let mut composer = ChatComposer::new(true, sender, false);
|
||||
|
||||
// Define test cases: (cursor_position_from_end, expected_pending_count)
|
||||
let test_cases = [
|
||||
5, // Delete from middle - should clear tracking
|
||||
0, // Delete from end - should clear tracking
|
||||
|
||||
@@ -25,8 +25,6 @@ pub(crate) struct CommandPopup {
|
||||
command_filter: String,
|
||||
all_commands: Vec<(&'static str, SlashCommand)>,
|
||||
selected_idx: Option<usize>,
|
||||
// Index of the first visible row in the filtered list.
|
||||
scroll_top: usize,
|
||||
}
|
||||
|
||||
impl CommandPopup {
|
||||
@@ -35,7 +33,6 @@ impl CommandPopup {
|
||||
command_filter: String::new(),
|
||||
all_commands: built_in_slash_commands(),
|
||||
selected_idx: None,
|
||||
scroll_top: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,28 +66,26 @@ impl CommandPopup {
|
||||
0 => None,
|
||||
_ => Some(self.selected_idx.unwrap_or(0).min(matches_len - 1)),
|
||||
};
|
||||
|
||||
self.adjust_scroll(matches_len);
|
||||
}
|
||||
|
||||
/// Determine the preferred height of the popup. This is the number of
|
||||
/// rows required to show at most MAX_POPUP_ROWS commands.
|
||||
/// rows required to show **at most** `MAX_POPUP_ROWS` commands plus the
|
||||
/// table/border overhead (one line at the top and one at the bottom).
|
||||
pub(crate) fn calculate_required_height(&self) -> u16 {
|
||||
self.filtered_commands().len().clamp(1, MAX_POPUP_ROWS) as u16
|
||||
}
|
||||
|
||||
/// Return the list of commands that match the current filter. Matching is
|
||||
/// performed using a case-insensitive prefix comparison on the command name.
|
||||
/// performed using a *prefix* comparison on the command name.
|
||||
fn filtered_commands(&self) -> Vec<&SlashCommand> {
|
||||
let filter = self.command_filter.as_str();
|
||||
self.all_commands
|
||||
.iter()
|
||||
.filter_map(|(_name, cmd)| {
|
||||
if filter.is_empty() {
|
||||
return Some(cmd);
|
||||
}
|
||||
let name = cmd.command();
|
||||
if name.len() >= filter.len() && name[..filter.len()].eq_ignore_ascii_case(filter) {
|
||||
if self.command_filter.is_empty()
|
||||
|| cmd
|
||||
.command()
|
||||
.starts_with(&self.command_filter.to_ascii_lowercase())
|
||||
{
|
||||
Some(cmd)
|
||||
} else {
|
||||
None
|
||||
@@ -101,30 +96,26 @@ impl CommandPopup {
|
||||
|
||||
/// Move the selection cursor one step up.
|
||||
pub(crate) fn move_up(&mut self) {
|
||||
let matches = self.filtered_commands();
|
||||
let len = matches.len();
|
||||
if len == 0 {
|
||||
self.selected_idx = None;
|
||||
self.scroll_top = 0;
|
||||
return;
|
||||
if let Some(len) = self.filtered_commands().len().checked_sub(1) {
|
||||
if len == usize::MAX {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
match self.selected_idx {
|
||||
Some(idx) if idx > 0 => self.selected_idx = Some(idx - 1),
|
||||
Some(_) => self.selected_idx = Some(len - 1), // wrap to last
|
||||
None => self.selected_idx = Some(0),
|
||||
if let Some(idx) = self.selected_idx {
|
||||
if idx > 0 {
|
||||
self.selected_idx = Some(idx - 1);
|
||||
}
|
||||
} else if !self.filtered_commands().is_empty() {
|
||||
self.selected_idx = Some(0);
|
||||
}
|
||||
|
||||
self.adjust_scroll(len);
|
||||
}
|
||||
|
||||
/// Move the selection cursor one step down.
|
||||
pub(crate) fn move_down(&mut self) {
|
||||
let matches = self.filtered_commands();
|
||||
let matches_len = matches.len();
|
||||
let matches_len = self.filtered_commands().len();
|
||||
if matches_len == 0 {
|
||||
self.selected_idx = None;
|
||||
self.scroll_top = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -132,15 +123,11 @@ impl CommandPopup {
|
||||
Some(idx) if idx + 1 < matches_len => {
|
||||
self.selected_idx = Some(idx + 1);
|
||||
}
|
||||
Some(_idx_last) => {
|
||||
self.selected_idx = Some(0);
|
||||
}
|
||||
None => {
|
||||
self.selected_idx = Some(0);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
self.adjust_scroll(matches_len);
|
||||
}
|
||||
|
||||
/// Return currently selected command, if any.
|
||||
@@ -148,26 +135,6 @@ impl CommandPopup {
|
||||
let matches = self.filtered_commands();
|
||||
self.selected_idx.and_then(|idx| matches.get(idx).copied())
|
||||
}
|
||||
|
||||
fn adjust_scroll(&mut self, matches_len: usize) {
|
||||
if matches_len == 0 {
|
||||
self.scroll_top = 0;
|
||||
return;
|
||||
}
|
||||
let visible_rows = MAX_POPUP_ROWS.min(matches_len);
|
||||
if let Some(sel) = self.selected_idx {
|
||||
if sel < self.scroll_top {
|
||||
self.scroll_top = sel;
|
||||
} else {
|
||||
let bottom = self.scroll_top + visible_rows - 1;
|
||||
if sel > bottom {
|
||||
self.scroll_top = sel + 1 - visible_rows;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.scroll_top = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WidgetRef for CommandPopup {
|
||||
@@ -175,8 +142,10 @@ impl WidgetRef for CommandPopup {
|
||||
let matches = self.filtered_commands();
|
||||
|
||||
let mut rows: Vec<Row> = Vec::new();
|
||||
let visible_matches: Vec<&SlashCommand> =
|
||||
matches.into_iter().take(MAX_POPUP_ROWS).collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
if visible_matches.is_empty() {
|
||||
rows.push(Row::new(vec![
|
||||
Cell::from(""),
|
||||
Cell::from("No matching commands").add_modifier(Modifier::ITALIC),
|
||||
@@ -184,32 +153,10 @@ impl WidgetRef for CommandPopup {
|
||||
} else {
|
||||
let default_style = Style::default();
|
||||
let command_style = Style::default().fg(Color::LightBlue);
|
||||
let max_rows_from_area = area.height as usize;
|
||||
let visible_rows = MAX_POPUP_ROWS
|
||||
.min(matches.len())
|
||||
.min(max_rows_from_area.max(1));
|
||||
|
||||
let mut start_idx = self.scroll_top.min(matches.len().saturating_sub(1));
|
||||
if let Some(sel) = self.selected_idx {
|
||||
if sel < start_idx {
|
||||
start_idx = sel;
|
||||
} else if visible_rows > 0 {
|
||||
let bottom = start_idx + visible_rows - 1;
|
||||
if sel > bottom {
|
||||
start_idx = sel + 1 - visible_rows;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (global_idx, cmd) in matches
|
||||
.iter()
|
||||
.enumerate()
|
||||
.skip(start_idx)
|
||||
.take(visible_rows)
|
||||
{
|
||||
for (idx, cmd) in visible_matches.iter().enumerate() {
|
||||
rows.push(Row::new(vec![
|
||||
Cell::from(Line::from(vec![
|
||||
if Some(global_idx) == self.selected_idx {
|
||||
if Some(idx) == self.selected_idx {
|
||||
Span::styled(
|
||||
"›",
|
||||
Style::default().bg(Color::DarkGray).fg(Color::LightCyan),
|
||||
@@ -241,67 +188,3 @@ impl WidgetRef for CommandPopup {
|
||||
table.render(area, buf);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ratatui::buffer::Buffer;
|
||||
use ratatui::layout::Rect;
|
||||
|
||||
#[test]
|
||||
fn move_down_wraps_to_top() {
|
||||
let mut popup = CommandPopup::new();
|
||||
// Show all commands by simulating composer input starting with '/'.
|
||||
popup.on_composer_text_change("/".to_string());
|
||||
let len = popup.filtered_commands().len();
|
||||
assert!(len > 0);
|
||||
|
||||
// Move to last item.
|
||||
for _ in 0..len.saturating_sub(1) {
|
||||
popup.move_down();
|
||||
}
|
||||
// Next move_down should wrap to index 0.
|
||||
popup.move_down();
|
||||
assert_eq!(popup.selected_idx, Some(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn move_up_wraps_to_bottom() {
|
||||
let mut popup = CommandPopup::new();
|
||||
popup.on_composer_text_change("/".to_string());
|
||||
let len = popup.filtered_commands().len();
|
||||
assert!(len > 0);
|
||||
|
||||
// Initial selection is 0; moving up should wrap to last.
|
||||
popup.move_up();
|
||||
assert_eq!(popup.selected_idx, Some(len - 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn respects_tiny_terminal_height_when_rendering() {
|
||||
let mut popup = CommandPopup::new();
|
||||
popup.on_composer_text_change("/".to_string());
|
||||
assert!(popup.filtered_commands().len() >= 3);
|
||||
|
||||
let area = Rect::new(0, 0, 50, 2);
|
||||
let mut buf = Buffer::empty(area);
|
||||
popup.render(area, &mut buf);
|
||||
|
||||
let mut non_empty_rows = 0u16;
|
||||
for y in 0..area.height {
|
||||
let mut row_has_content = false;
|
||||
for x in 0..area.width {
|
||||
let c = buf[(x, y)].symbol();
|
||||
if !c.trim().is_empty() {
|
||||
row_has_content = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if row_has_content {
|
||||
non_empty_rows += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(non_empty_rows, 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ impl FileSearchPopup {
|
||||
.map(|file_match| file_match.path.as_str())
|
||||
}
|
||||
|
||||
/// Preferred height (rows) including border.
|
||||
pub(crate) fn calculate_required_height(&self) -> u16 {
|
||||
// Row count depends on whether we already have matches. If no matches
|
||||
// yet (e.g. initial search or query with no results) reserve a single
|
||||
|
||||
@@ -374,6 +374,7 @@ impl ChatWidget<'_> {
|
||||
);
|
||||
self.add_to_history(HistoryCell::new_active_exec_command(command));
|
||||
}
|
||||
EventMsg::ExecCommandOutputDelta(_) => {}
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id: _,
|
||||
auto_approved,
|
||||
|
||||
@@ -545,24 +545,17 @@ impl HistoryCell {
|
||||
} else {
|
||||
for (idx, PlanItemArg { step, status }) in plan.into_iter().enumerate() {
|
||||
let num = idx + 1;
|
||||
let (icon, style): (&str, Style) = match status {
|
||||
StepStatus::Completed => ("✓", Style::default().fg(Color::Green)),
|
||||
StepStatus::InProgress => (
|
||||
"▶",
|
||||
Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
StepStatus::Pending => ("○", Style::default().fg(Color::Gray)),
|
||||
let icon_span: Span = match status {
|
||||
StepStatus::Completed => Span::from("✓").fg(Color::Green),
|
||||
StepStatus::InProgress => Span::from("▶").fg(Color::Yellow).bold(),
|
||||
StepStatus::Pending => Span::from("○").fg(Color::Gray),
|
||||
};
|
||||
let prefix = vec![
|
||||
Span::raw(format!("{num:>2}. [")),
|
||||
Span::styled(icon.to_string(), style),
|
||||
Span::raw("] "),
|
||||
];
|
||||
let mut spans = prefix;
|
||||
spans.push(Span::raw(step));
|
||||
lines.push(Line::from(spans));
|
||||
lines.push(Line::from(vec![
|
||||
format!("{num:>2}. [").into(),
|
||||
icon_span,
|
||||
"] ".into(),
|
||||
step.into(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -216,18 +216,18 @@ where
|
||||
{
|
||||
let mut fg = Color::Reset;
|
||||
let mut bg = Color::Reset;
|
||||
let mut modifier = Modifier::empty();
|
||||
let mut last_modifier = Modifier::empty();
|
||||
for span in content {
|
||||
let mut next_modifier = modifier;
|
||||
next_modifier.insert(span.style.add_modifier);
|
||||
next_modifier.remove(span.style.sub_modifier);
|
||||
if next_modifier != modifier {
|
||||
let mut modifier = Modifier::empty();
|
||||
modifier.insert(span.style.add_modifier);
|
||||
modifier.remove(span.style.sub_modifier);
|
||||
if modifier != last_modifier {
|
||||
let diff = ModifierDiff {
|
||||
from: modifier,
|
||||
to: next_modifier,
|
||||
from: last_modifier,
|
||||
to: modifier,
|
||||
};
|
||||
diff.queue(&mut writer)?;
|
||||
modifier = next_modifier;
|
||||
last_modifier = modifier;
|
||||
}
|
||||
let next_fg = span.style.fg.unwrap_or(Color::Reset);
|
||||
let next_bg = span.style.bg.unwrap_or(Color::Reset);
|
||||
@@ -250,3 +250,37 @@ where
|
||||
SetAttribute(crossterm::style::Attribute::Reset),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn writes_bold_then_regular_spans() {
|
||||
use ratatui::style::Stylize;
|
||||
|
||||
let spans = ["A".bold(), "B".into()];
|
||||
|
||||
let mut actual: Vec<u8> = Vec::new();
|
||||
write_spans(&mut actual, spans.iter()).unwrap();
|
||||
|
||||
let mut expected: Vec<u8> = Vec::new();
|
||||
queue!(
|
||||
expected,
|
||||
SetAttribute(crossterm::style::Attribute::Bold),
|
||||
Print("A"),
|
||||
SetAttribute(crossterm::style::Attribute::NormalIntensity),
|
||||
Print("B"),
|
||||
SetForegroundColor(CColor::Reset),
|
||||
SetBackgroundColor(CColor::Reset),
|
||||
SetAttribute(crossterm::style::Attribute::Reset),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
String::from_utf8(actual).unwrap(),
|
||||
String::from_utf8(expected).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,11 @@ mod text_formatting;
|
||||
mod tui;
|
||||
mod user_approval_widget;
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
mod updates;
|
||||
#[cfg(not(debug_assertions))]
|
||||
use color_eyre::owo_colors::OwoColorize;
|
||||
|
||||
pub use cli::Cli;
|
||||
|
||||
pub async fn run_main(
|
||||
@@ -139,6 +144,38 @@ pub async fn run_main(
|
||||
.with(tui_layer)
|
||||
.try_init();
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
#[cfg(not(debug_assertions))]
|
||||
if let Some(latest_version) = updates::get_upgrade_version(&config) {
|
||||
let current_version = env!("CARGO_PKG_VERSION");
|
||||
let exe = std::env::current_exe()?;
|
||||
let managed_by_npm = std::env::var_os("CODEX_MANAGED_BY_NPM").is_some();
|
||||
|
||||
eprintln!(
|
||||
"{} {current_version} -> {latest_version}.",
|
||||
"✨⬆️ Update available!".bold().cyan()
|
||||
);
|
||||
|
||||
if managed_by_npm {
|
||||
let npm_cmd = "npm install -g @openai/codex@latest";
|
||||
eprintln!("Run {} to update.", npm_cmd.cyan().on_black());
|
||||
} else if cfg!(target_os = "macos")
|
||||
&& (exe.starts_with("/opt/homebrew") || exe.starts_with("/usr/local"))
|
||||
{
|
||||
let brew_cmd = "brew upgrade codex";
|
||||
eprintln!("Run {} to update.", brew_cmd.cyan().on_black());
|
||||
} else {
|
||||
eprintln!(
|
||||
"See {} for the latest releases and installation options.",
|
||||
"https://github.com/openai/codex/releases/latest"
|
||||
.cyan()
|
||||
.on_black()
|
||||
);
|
||||
}
|
||||
|
||||
eprintln!("");
|
||||
}
|
||||
|
||||
let show_login_screen = should_show_login_screen(&config);
|
||||
if show_login_screen {
|
||||
std::io::stdout()
|
||||
|
||||
137
codex-rs/tui/src/updates.rs
Normal file
137
codex-rs/tui/src/updates.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
#![cfg(any(not(debug_assertions), test))]
|
||||
|
||||
use chrono::DateTime;
|
||||
use chrono::Duration;
|
||||
use chrono::Utc;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::config::Config;
|
||||
|
||||
pub fn get_upgrade_version(config: &Config) -> Option<String> {
|
||||
let version_file = version_filepath(config);
|
||||
let info = read_version_info(&version_file).ok();
|
||||
|
||||
if match &info {
|
||||
None => true,
|
||||
Some(info) => info.last_checked_at < Utc::now() - Duration::hours(20),
|
||||
} {
|
||||
// Refresh the cached latest version in the background so TUI startup
|
||||
// isn’t blocked by a network call. The UI reads the previously cached
|
||||
// value (if any) for this run; the next run shows the banner if needed.
|
||||
tokio::spawn(async move {
|
||||
check_for_update(&version_file)
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("Failed to update version: {e}"))
|
||||
});
|
||||
}
|
||||
|
||||
info.and_then(|info| {
|
||||
let current_version = env!("CARGO_PKG_VERSION");
|
||||
if is_newer(&info.latest_version, current_version).unwrap_or(false) {
|
||||
Some(info.latest_version)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct VersionInfo {
|
||||
latest_version: String,
|
||||
// ISO-8601 timestamp (RFC3339)
|
||||
last_checked_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct ReleaseInfo {
|
||||
tag_name: String,
|
||||
}
|
||||
|
||||
const VERSION_FILENAME: &str = "version.json";
|
||||
const LATEST_RELEASE_URL: &str = "https://api.github.com/repos/openai/codex/releases/latest";
|
||||
|
||||
fn version_filepath(config: &Config) -> PathBuf {
|
||||
config.codex_home.join(VERSION_FILENAME)
|
||||
}
|
||||
|
||||
fn read_version_info(version_file: &Path) -> anyhow::Result<VersionInfo> {
|
||||
let contents = std::fs::read_to_string(version_file)?;
|
||||
Ok(serde_json::from_str(&contents)?)
|
||||
}
|
||||
|
||||
async fn check_for_update(version_file: &Path) -> anyhow::Result<()> {
|
||||
let ReleaseInfo {
|
||||
tag_name: latest_tag_name,
|
||||
} = reqwest::Client::new()
|
||||
.get(LATEST_RELEASE_URL)
|
||||
.header(
|
||||
"User-Agent",
|
||||
format!(
|
||||
"codex/{} (+https://github.com/openai/codex)",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
),
|
||||
)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json::<ReleaseInfo>()
|
||||
.await?;
|
||||
|
||||
let info = VersionInfo {
|
||||
latest_version: latest_tag_name
|
||||
.strip_prefix("rust-v")
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to parse latest tag name '{latest_tag_name}'"))?
|
||||
.into(),
|
||||
last_checked_at: Utc::now(),
|
||||
};
|
||||
|
||||
let json_line = format!("{}\n", serde_json::to_string(&info)?);
|
||||
if let Some(parent) = version_file.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
tokio::fs::write(version_file, json_line).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_newer(latest: &str, current: &str) -> Option<bool> {
|
||||
match (parse_version(latest), parse_version(current)) {
|
||||
(Some(l), Some(c)) => Some(l > c),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_version(v: &str) -> Option<(u64, u64, u64)> {
|
||||
let mut iter = v.trim().split('.');
|
||||
let maj = iter.next()?.parse::<u64>().ok()?;
|
||||
let min = iter.next()?.parse::<u64>().ok()?;
|
||||
let pat = iter.next()?.parse::<u64>().ok()?;
|
||||
Some((maj, min, pat))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn prerelease_version_is_not_considered_newer() {
|
||||
assert_eq!(is_newer("0.11.0-beta.1", "0.11.0"), None);
|
||||
assert_eq!(is_newer("1.0.0-rc.1", "1.0.0"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plain_semver_comparisons_work() {
|
||||
assert_eq!(is_newer("0.11.1", "0.11.0"), Some(true));
|
||||
assert_eq!(is_newer("0.11.0", "0.11.1"), Some(false));
|
||||
assert_eq!(is_newer("1.0.0", "0.9.9"), Some(true));
|
||||
assert_eq!(is_newer("0.9.9", "1.0.0"), Some(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn whitespace_is_ignored() {
|
||||
assert_eq!(parse_version(" 1.2.3 \n"), Some((1, 2, 3)));
|
||||
assert_eq!(is_newer(" 1.2.3 ", "1.2.2"), Some(true));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user