mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
feat: tie shell snapshot to cwd (#11231)
Fix for this: https://github.com/openai/codex/issues/11223 Basically we tie the shell snapshot to a `cwd` to handle `cwd`-based env setups
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Debug;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
@@ -87,6 +88,7 @@ use tokio::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
use tracing::debug;
|
||||
@@ -234,8 +236,6 @@ use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use codex_utils_readiness::ReadinessFlag;
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// The high-level interface to the Codex system.
|
||||
/// It operates as a queue pair where you send submissions and receive events.
|
||||
pub struct Codex {
|
||||
@@ -1033,14 +1033,19 @@ impl Session {
|
||||
|
||||
let mut default_shell = shell::default_user_shell();
|
||||
// Create the mutable state for the Session.
|
||||
if config.features.enabled(Feature::ShellSnapshot) {
|
||||
let shell_snapshot_tx = if config.features.enabled(Feature::ShellSnapshot) {
|
||||
ShellSnapshot::start_snapshotting(
|
||||
config.codex_home.clone(),
|
||||
conversation_id,
|
||||
session_configuration.cwd.clone(),
|
||||
&mut default_shell,
|
||||
otel_manager.clone(),
|
||||
);
|
||||
}
|
||||
)
|
||||
} else {
|
||||
let (tx, rx) = watch::channel(None);
|
||||
default_shell.shell_snapshot = rx;
|
||||
tx
|
||||
};
|
||||
let thread_name =
|
||||
match session_index::find_thread_name_by_id(&config.codex_home, &conversation_id).await
|
||||
{
|
||||
@@ -1064,6 +1069,7 @@ impl Session {
|
||||
hooks: Hooks::new(config.as_ref()),
|
||||
rollout: Mutex::new(rollout_recorder),
|
||||
user_shell: Arc::new(default_shell),
|
||||
shell_snapshot_tx,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
exec_policy,
|
||||
auth_manager: Arc::clone(&auth_manager),
|
||||
@@ -1405,6 +1411,30 @@ impl Session {
|
||||
state.pending_resume_previous_model.take()
|
||||
}
|
||||
|
||||
fn maybe_refresh_shell_snapshot_for_cwd(
|
||||
&self,
|
||||
previous_cwd: &Path,
|
||||
next_cwd: &Path,
|
||||
codex_home: &Path,
|
||||
) {
|
||||
if previous_cwd == next_cwd {
|
||||
return;
|
||||
}
|
||||
|
||||
if !self.features.enabled(Feature::ShellSnapshot) {
|
||||
return;
|
||||
}
|
||||
|
||||
ShellSnapshot::refresh_snapshot(
|
||||
codex_home.to_path_buf(),
|
||||
self.conversation_id,
|
||||
next_cwd.to_path_buf(),
|
||||
self.services.user_shell.as_ref().clone(),
|
||||
self.services.shell_snapshot_tx.clone(),
|
||||
self.services.otel_manager.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) async fn update_settings(
|
||||
&self,
|
||||
updates: SessionSettingsUpdate,
|
||||
@@ -1413,7 +1443,14 @@ impl Session {
|
||||
|
||||
match state.session_configuration.apply(&updates) {
|
||||
Ok(updated) => {
|
||||
let previous_cwd = state.session_configuration.cwd.clone();
|
||||
let next_cwd = updated.cwd.clone();
|
||||
let codex_home = updated.codex_home.clone();
|
||||
state.session_configuration = updated;
|
||||
drop(state);
|
||||
|
||||
self.maybe_refresh_shell_snapshot_for_cwd(&previous_cwd, &next_cwd, &codex_home);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
@@ -1428,14 +1465,16 @@ impl Session {
|
||||
sub_id: String,
|
||||
updates: SessionSettingsUpdate,
|
||||
) -> ConstraintResult<Arc<TurnContext>> {
|
||||
let (session_configuration, sandbox_policy_changed) = {
|
||||
let (session_configuration, sandbox_policy_changed, previous_cwd, codex_home) = {
|
||||
let mut state = self.state.lock().await;
|
||||
match state.session_configuration.clone().apply(&updates) {
|
||||
Ok(next) => {
|
||||
let previous_cwd = state.session_configuration.cwd.clone();
|
||||
let sandbox_policy_changed =
|
||||
state.session_configuration.sandbox_policy != next.sandbox_policy;
|
||||
let codex_home = next.codex_home.clone();
|
||||
state.session_configuration = next.clone();
|
||||
(next, sandbox_policy_changed)
|
||||
(next, sandbox_policy_changed, previous_cwd, codex_home)
|
||||
}
|
||||
Err(err) => {
|
||||
drop(state);
|
||||
@@ -1452,6 +1491,12 @@ impl Session {
|
||||
}
|
||||
};
|
||||
|
||||
self.maybe_refresh_shell_snapshot_for_cwd(
|
||||
&previous_cwd,
|
||||
&session_configuration.cwd,
|
||||
&codex_home,
|
||||
);
|
||||
|
||||
Ok(self
|
||||
.new_turn_from_configuration(
|
||||
sub_id,
|
||||
@@ -6123,6 +6168,7 @@ mod tests {
|
||||
hooks: Hooks::new(&config),
|
||||
rollout: Mutex::new(None),
|
||||
user_shell: Arc::new(default_user_shell()),
|
||||
shell_snapshot_tx: watch::channel(None).0,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
exec_policy,
|
||||
auth_manager: auth_manager.clone(),
|
||||
@@ -6255,6 +6301,7 @@ mod tests {
|
||||
hooks: Hooks::new(&config),
|
||||
rollout: Mutex::new(None),
|
||||
user_shell: Arc::new(default_user_shell()),
|
||||
shell_snapshot_tx: watch::channel(None).0,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
exec_policy,
|
||||
auth_manager: Arc::clone(&auth_manager),
|
||||
|
||||
@@ -26,6 +26,7 @@ use tracing::info_span;
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ShellSnapshot {
|
||||
pub path: PathBuf,
|
||||
pub cwd: PathBuf,
|
||||
}
|
||||
|
||||
const SNAPSHOT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
@@ -37,22 +38,63 @@ impl ShellSnapshot {
|
||||
pub fn start_snapshotting(
|
||||
codex_home: PathBuf,
|
||||
session_id: ThreadId,
|
||||
session_cwd: PathBuf,
|
||||
shell: &mut Shell,
|
||||
otel_manager: OtelManager,
|
||||
) {
|
||||
) -> watch::Sender<Option<Arc<ShellSnapshot>>> {
|
||||
let (shell_snapshot_tx, shell_snapshot_rx) = watch::channel(None);
|
||||
shell.shell_snapshot = shell_snapshot_rx;
|
||||
|
||||
let snapshot_shell = shell.clone();
|
||||
let snapshot_session_id = session_id;
|
||||
let snapshot_span = info_span!("shell_snapshot", thread_id = %snapshot_session_id);
|
||||
Self::spawn_snapshot_task(
|
||||
codex_home,
|
||||
session_id,
|
||||
session_cwd,
|
||||
shell.clone(),
|
||||
shell_snapshot_tx.clone(),
|
||||
otel_manager,
|
||||
);
|
||||
|
||||
shell_snapshot_tx
|
||||
}
|
||||
|
||||
pub fn refresh_snapshot(
|
||||
codex_home: PathBuf,
|
||||
session_id: ThreadId,
|
||||
session_cwd: PathBuf,
|
||||
shell: Shell,
|
||||
shell_snapshot_tx: watch::Sender<Option<Arc<ShellSnapshot>>>,
|
||||
otel_manager: OtelManager,
|
||||
) {
|
||||
Self::spawn_snapshot_task(
|
||||
codex_home,
|
||||
session_id,
|
||||
session_cwd,
|
||||
shell,
|
||||
shell_snapshot_tx,
|
||||
otel_manager,
|
||||
);
|
||||
}
|
||||
|
||||
fn spawn_snapshot_task(
|
||||
codex_home: PathBuf,
|
||||
session_id: ThreadId,
|
||||
session_cwd: PathBuf,
|
||||
snapshot_shell: Shell,
|
||||
shell_snapshot_tx: watch::Sender<Option<Arc<ShellSnapshot>>>,
|
||||
otel_manager: OtelManager,
|
||||
) {
|
||||
let snapshot_span = info_span!("shell_snapshot", thread_id = %session_id);
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let timer = otel_manager.start_timer("codex.shell_snapshot.duration_ms", &[]);
|
||||
let snapshot =
|
||||
ShellSnapshot::try_new(&codex_home, snapshot_session_id, &snapshot_shell)
|
||||
.await
|
||||
.map(Arc::new);
|
||||
let snapshot = ShellSnapshot::try_new(
|
||||
&codex_home,
|
||||
session_id,
|
||||
session_cwd.as_path(),
|
||||
&snapshot_shell,
|
||||
)
|
||||
.await
|
||||
.map(Arc::new);
|
||||
let success = if snapshot.is_some() { "true" } else { "false" };
|
||||
let _ = timer.map(|timer| timer.record(&[("success", success)]));
|
||||
otel_manager.counter("codex.shell_snapshot", 1, &[("success", success)]);
|
||||
@@ -62,7 +104,12 @@ impl ShellSnapshot {
|
||||
);
|
||||
}
|
||||
|
||||
async fn try_new(codex_home: &Path, session_id: ThreadId, shell: &Shell) -> Option<Self> {
|
||||
async fn try_new(
|
||||
codex_home: &Path,
|
||||
session_id: ThreadId,
|
||||
session_cwd: &Path,
|
||||
shell: &Shell,
|
||||
) -> Option<Self> {
|
||||
// File to store the snapshot
|
||||
let extension = match shell.shell_type {
|
||||
ShellType::PowerShell => "ps1",
|
||||
@@ -82,22 +129,26 @@ impl ShellSnapshot {
|
||||
});
|
||||
|
||||
// Make the new snapshot.
|
||||
let snapshot = match write_shell_snapshot(shell.shell_type.clone(), &path).await {
|
||||
Ok(path) => {
|
||||
tracing::info!("Shell snapshot successfully created: {}", path.display());
|
||||
Some(Self { path })
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"Failed to create shell snapshot for {}: {err:?}",
|
||||
shell.name()
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
let snapshot =
|
||||
match write_shell_snapshot(shell.shell_type.clone(), &path, session_cwd).await {
|
||||
Ok(path) => {
|
||||
tracing::info!("Shell snapshot successfully created: {}", path.display());
|
||||
Some(Self {
|
||||
path,
|
||||
cwd: session_cwd.to_path_buf(),
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"Failed to create shell snapshot for {}: {err:?}",
|
||||
shell.name()
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(snapshot) = snapshot.as_ref()
|
||||
&& let Err(err) = validate_snapshot(shell, &snapshot.path).await
|
||||
&& let Err(err) = validate_snapshot(shell, &snapshot.path, session_cwd).await
|
||||
{
|
||||
tracing::error!("Shell snapshot validation failed: {err:?}");
|
||||
return None;
|
||||
@@ -118,14 +169,18 @@ impl Drop for ShellSnapshot {
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_shell_snapshot(shell_type: ShellType, output_path: &Path) -> Result<PathBuf> {
|
||||
async fn write_shell_snapshot(
|
||||
shell_type: ShellType,
|
||||
output_path: &Path,
|
||||
cwd: &Path,
|
||||
) -> Result<PathBuf> {
|
||||
if shell_type == ShellType::PowerShell || shell_type == ShellType::Cmd {
|
||||
bail!("Shell snapshot not supported yet for {shell_type:?}");
|
||||
}
|
||||
let shell = get_shell(shell_type.clone(), None)
|
||||
.with_context(|| format!("No available shell for {shell_type:?}"))?;
|
||||
|
||||
let raw_snapshot = capture_snapshot(&shell).await?;
|
||||
let raw_snapshot = capture_snapshot(&shell, cwd).await?;
|
||||
let snapshot = strip_snapshot_preamble(&raw_snapshot)?;
|
||||
|
||||
if let Some(parent) = output_path.parent() {
|
||||
@@ -143,13 +198,13 @@ async fn write_shell_snapshot(shell_type: ShellType, output_path: &Path) -> Resu
|
||||
Ok(output_path.to_path_buf())
|
||||
}
|
||||
|
||||
async fn capture_snapshot(shell: &Shell) -> Result<String> {
|
||||
async fn capture_snapshot(shell: &Shell, cwd: &Path) -> Result<String> {
|
||||
let shell_type = shell.shell_type.clone();
|
||||
match shell_type {
|
||||
ShellType::Zsh => run_shell_script(shell, &zsh_snapshot_script()).await,
|
||||
ShellType::Bash => run_shell_script(shell, &bash_snapshot_script()).await,
|
||||
ShellType::Sh => run_shell_script(shell, &sh_snapshot_script()).await,
|
||||
ShellType::PowerShell => run_shell_script(shell, powershell_snapshot_script()).await,
|
||||
ShellType::Zsh => run_shell_script(shell, &zsh_snapshot_script(), cwd).await,
|
||||
ShellType::Bash => run_shell_script(shell, &bash_snapshot_script(), cwd).await,
|
||||
ShellType::Sh => run_shell_script(shell, &sh_snapshot_script(), cwd).await,
|
||||
ShellType::PowerShell => run_shell_script(shell, powershell_snapshot_script(), cwd).await,
|
||||
ShellType::Cmd => bail!("Shell snapshotting is not yet supported for {shell_type:?}"),
|
||||
}
|
||||
}
|
||||
@@ -163,16 +218,16 @@ fn strip_snapshot_preamble(snapshot: &str) -> Result<String> {
|
||||
Ok(snapshot[start..].to_string())
|
||||
}
|
||||
|
||||
async fn validate_snapshot(shell: &Shell, snapshot_path: &Path) -> Result<()> {
|
||||
async fn validate_snapshot(shell: &Shell, snapshot_path: &Path, cwd: &Path) -> Result<()> {
|
||||
let snapshot_path_display = snapshot_path.display();
|
||||
let script = format!("set -e; . \"{snapshot_path_display}\"");
|
||||
run_script_with_timeout(shell, &script, SNAPSHOT_TIMEOUT, false)
|
||||
run_script_with_timeout(shell, &script, SNAPSHOT_TIMEOUT, false, cwd)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
async fn run_shell_script(shell: &Shell, script: &str) -> Result<String> {
|
||||
run_script_with_timeout(shell, script, SNAPSHOT_TIMEOUT, true).await
|
||||
async fn run_shell_script(shell: &Shell, script: &str, cwd: &Path) -> Result<String> {
|
||||
run_script_with_timeout(shell, script, SNAPSHOT_TIMEOUT, true, cwd).await
|
||||
}
|
||||
|
||||
async fn run_script_with_timeout(
|
||||
@@ -180,6 +235,7 @@ async fn run_script_with_timeout(
|
||||
script: &str,
|
||||
snapshot_timeout: Duration,
|
||||
use_login_shell: bool,
|
||||
cwd: &Path,
|
||||
) -> Result<String> {
|
||||
let args = shell.derive_exec_args(script, use_login_shell);
|
||||
let shell_name = shell.name();
|
||||
@@ -189,6 +245,7 @@ async fn run_script_with_timeout(
|
||||
let mut handler = Command::new(&args[0]);
|
||||
handler.args(&args[1..]);
|
||||
handler.stdin(Stdio::null());
|
||||
handler.current_dir(cwd);
|
||||
#[cfg(unix)]
|
||||
unsafe {
|
||||
handler.pre_exec(|| {
|
||||
@@ -550,7 +607,7 @@ mod tests {
|
||||
async fn get_snapshot(shell_type: ShellType) -> Result<String> {
|
||||
let dir = tempdir()?;
|
||||
let path = dir.path().join("snapshot.sh");
|
||||
write_shell_snapshot(shell_type, &path).await?;
|
||||
write_shell_snapshot(shell_type, &path, dir.path()).await?;
|
||||
let content = fs::read_to_string(&path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
@@ -602,11 +659,12 @@ mod tests {
|
||||
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
|
||||
};
|
||||
|
||||
let snapshot = ShellSnapshot::try_new(dir.path(), ThreadId::new(), &shell)
|
||||
let snapshot = ShellSnapshot::try_new(dir.path(), ThreadId::new(), dir.path(), &shell)
|
||||
.await
|
||||
.expect("snapshot should be created");
|
||||
let path = snapshot.path.clone();
|
||||
assert!(path.exists());
|
||||
assert_eq!(snapshot.cwd, dir.path().to_path_buf());
|
||||
|
||||
drop(snapshot);
|
||||
|
||||
@@ -635,9 +693,10 @@ mod tests {
|
||||
"HOME=\"{home_display}\"; export HOME; {}",
|
||||
bash_snapshot_script()
|
||||
);
|
||||
let output = run_script_with_timeout(&shell, &script, Duration::from_millis(500), true)
|
||||
.await
|
||||
.context("run snapshot command")?;
|
||||
let output =
|
||||
run_script_with_timeout(&shell, &script, Duration::from_millis(500), true, home)
|
||||
.await
|
||||
.context("run snapshot command")?;
|
||||
|
||||
assert!(
|
||||
output.contains("# Snapshot file"),
|
||||
@@ -665,9 +724,10 @@ mod tests {
|
||||
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
|
||||
};
|
||||
|
||||
let err = run_script_with_timeout(&shell, &script, Duration::from_secs(1), true)
|
||||
.await
|
||||
.expect_err("snapshot shell should time out");
|
||||
let err =
|
||||
run_script_with_timeout(&shell, &script, Duration::from_secs(1), true, dir.path())
|
||||
.await
|
||||
.expect_err("snapshot shell should time out");
|
||||
assert!(
|
||||
err.to_string().contains("timed out"),
|
||||
"expected timeout error, got {err:?}"
|
||||
|
||||
@@ -17,6 +17,7 @@ use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use codex_otel::OtelManager;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::watch;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub(crate) struct SessionServices {
|
||||
@@ -27,6 +28,7 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) hooks: Hooks,
|
||||
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
|
||||
pub(crate) user_shell: Arc<crate::shell::Shell>,
|
||||
pub(crate) shell_snapshot_tx: watch::Sender<Option<Arc<crate::shell_snapshot::ShellSnapshot>>>,
|
||||
pub(crate) show_raw_agent_reasoning: bool,
|
||||
pub(crate) exec_policy: ExecPolicyManager,
|
||||
pub(crate) auth_manager: Arc<AuthManager>,
|
||||
|
||||
@@ -113,7 +113,11 @@ pub(crate) async fn execute_user_shell_command(
|
||||
let use_login_shell = true;
|
||||
let session_shell = session.user_shell();
|
||||
let display_command = session_shell.derive_exec_args(&command, use_login_shell);
|
||||
let exec_command = maybe_wrap_shell_lc_with_snapshot(&display_command, session_shell.as_ref());
|
||||
let exec_command = maybe_wrap_shell_lc_with_snapshot(
|
||||
&display_command,
|
||||
session_shell.as_ref(),
|
||||
turn_context.cwd.as_path(),
|
||||
);
|
||||
|
||||
let call_id = Uuid::new_v4().to_string();
|
||||
let raw_command = command;
|
||||
|
||||
@@ -455,6 +455,7 @@ mod tests {
|
||||
fn shell_command_handler_respects_explicit_login_flag() {
|
||||
let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot {
|
||||
path: PathBuf::from("/tmp/snapshot.sh"),
|
||||
cwd: PathBuf::from("/tmp"),
|
||||
})));
|
||||
let shell = Shell {
|
||||
shell_type: ShellType::Bash,
|
||||
|
||||
@@ -5,6 +5,7 @@ Concrete ToolRuntime implementations for specific tools. Each runtime stays
|
||||
small and focused and reuses the orchestrator for approvals + sandbox + retry.
|
||||
*/
|
||||
use crate::exec::ExecExpiration;
|
||||
use crate::path_utils;
|
||||
use crate::sandboxing::CommandSpec;
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
use crate::shell::Shell;
|
||||
@@ -50,10 +51,12 @@ pub(crate) fn build_command_spec(
|
||||
/// => user_shell -c ". SNAPSHOT (best effort); exec shell -c <script>"
|
||||
///
|
||||
/// This wrapper script uses POSIX constructs (`if`, `.`, `exec`) so it can
|
||||
/// be run by Bash/Zsh/sh. On non-matching commands this is a no-op.
|
||||
/// be run by Bash/Zsh/sh. On non-matching commands, or when command cwd does
|
||||
/// not match the snapshot cwd, this is a no-op.
|
||||
pub(crate) fn maybe_wrap_shell_lc_with_snapshot(
|
||||
command: &[String],
|
||||
session_shell: &Shell,
|
||||
cwd: &Path,
|
||||
) -> Vec<String> {
|
||||
let Some(snapshot) = session_shell.shell_snapshot() else {
|
||||
return command.to_vec();
|
||||
@@ -63,6 +66,17 @@ pub(crate) fn maybe_wrap_shell_lc_with_snapshot(
|
||||
return command.to_vec();
|
||||
}
|
||||
|
||||
if if let (Ok(snapshot_cwd), Ok(command_cwd)) = (
|
||||
path_utils::normalize_for_path_comparison(snapshot.cwd.as_path()),
|
||||
path_utils::normalize_for_path_comparison(cwd),
|
||||
) {
|
||||
snapshot_cwd != command_cwd
|
||||
} else {
|
||||
snapshot.cwd != cwd
|
||||
} {
|
||||
return command.to_vec();
|
||||
}
|
||||
|
||||
if command.len() < 3 {
|
||||
return command.to_vec();
|
||||
}
|
||||
@@ -107,9 +121,11 @@ mod tests {
|
||||
shell_type: ShellType,
|
||||
shell_path: &str,
|
||||
snapshot_path: PathBuf,
|
||||
snapshot_cwd: PathBuf,
|
||||
) -> Shell {
|
||||
let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot {
|
||||
path: snapshot_path,
|
||||
cwd: snapshot_cwd,
|
||||
})));
|
||||
Shell {
|
||||
shell_type,
|
||||
@@ -123,14 +139,19 @@ mod tests {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let session_shell = shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path);
|
||||
let session_shell = shell_with_snapshot(
|
||||
ShellType::Zsh,
|
||||
"/bin/zsh",
|
||||
snapshot_path,
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let command = vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"echo hello".to_string(),
|
||||
];
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell);
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path());
|
||||
|
||||
assert_eq!(rewritten[0], "/bin/zsh");
|
||||
assert_eq!(rewritten[1], "-c");
|
||||
@@ -143,14 +164,19 @@ mod tests {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let session_shell = shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path);
|
||||
let session_shell = shell_with_snapshot(
|
||||
ShellType::Zsh,
|
||||
"/bin/zsh",
|
||||
snapshot_path,
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let command = vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"echo 'hello'".to_string(),
|
||||
];
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell);
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path());
|
||||
|
||||
assert!(rewritten[2].contains(r#"exec '/bin/bash' -c 'echo '"'"'hello'"'"''"#));
|
||||
}
|
||||
@@ -160,14 +186,19 @@ mod tests {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let session_shell = shell_with_snapshot(ShellType::Bash, "/bin/bash", snapshot_path);
|
||||
let session_shell = shell_with_snapshot(
|
||||
ShellType::Bash,
|
||||
"/bin/bash",
|
||||
snapshot_path,
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let command = vec![
|
||||
"/bin/zsh".to_string(),
|
||||
"-lc".to_string(),
|
||||
"echo hello".to_string(),
|
||||
];
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell);
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path());
|
||||
|
||||
assert_eq!(rewritten[0], "/bin/bash");
|
||||
assert_eq!(rewritten[1], "-c");
|
||||
@@ -180,14 +211,19 @@ mod tests {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let session_shell = shell_with_snapshot(ShellType::Sh, "/bin/sh", snapshot_path);
|
||||
let session_shell = shell_with_snapshot(
|
||||
ShellType::Sh,
|
||||
"/bin/sh",
|
||||
snapshot_path,
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let command = vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"echo hello".to_string(),
|
||||
];
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell);
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path());
|
||||
|
||||
assert_eq!(rewritten[0], "/bin/sh");
|
||||
assert_eq!(rewritten[1], "-c");
|
||||
@@ -200,7 +236,12 @@ mod tests {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let session_shell = shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path);
|
||||
let session_shell = shell_with_snapshot(
|
||||
ShellType::Zsh,
|
||||
"/bin/zsh",
|
||||
snapshot_path,
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let command = vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
@@ -209,7 +250,7 @@ mod tests {
|
||||
"arg1".to_string(),
|
||||
];
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell);
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path());
|
||||
|
||||
assert!(
|
||||
rewritten[2].contains(
|
||||
@@ -217,4 +258,52 @@ mod tests {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_wrap_shell_lc_with_snapshot_skips_when_cwd_mismatch() {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let snapshot_cwd = dir.path().join("worktree-a");
|
||||
let command_cwd = dir.path().join("worktree-b");
|
||||
std::fs::create_dir_all(&snapshot_cwd).expect("create snapshot cwd");
|
||||
std::fs::create_dir_all(&command_cwd).expect("create command cwd");
|
||||
let session_shell =
|
||||
shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path, snapshot_cwd);
|
||||
let command = vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"echo hello".to_string(),
|
||||
];
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, &command_cwd);
|
||||
|
||||
assert_eq!(rewritten, command);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_wrap_shell_lc_with_snapshot_accepts_dot_alias_cwd() {
|
||||
let dir = tempdir().expect("create temp dir");
|
||||
let snapshot_path = dir.path().join("snapshot.sh");
|
||||
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
|
||||
let session_shell = shell_with_snapshot(
|
||||
ShellType::Zsh,
|
||||
"/bin/zsh",
|
||||
snapshot_path,
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let command = vec![
|
||||
"/bin/bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"echo hello".to_string(),
|
||||
];
|
||||
let command_cwd = dir.path().join(".");
|
||||
|
||||
let rewritten = maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, &command_cwd);
|
||||
|
||||
assert_eq!(rewritten[0], "/bin/zsh");
|
||||
assert_eq!(rewritten[1], "-c");
|
||||
assert!(rewritten[2].contains("if . '"));
|
||||
assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,7 +148,8 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
let base_command = &req.command;
|
||||
let session_shell = ctx.session.user_shell();
|
||||
let command = maybe_wrap_shell_lc_with_snapshot(base_command, session_shell.as_ref());
|
||||
let command =
|
||||
maybe_wrap_shell_lc_with_snapshot(base_command, session_shell.as_ref(), &req.cwd);
|
||||
let command = if matches!(session_shell.shell_type, ShellType::PowerShell)
|
||||
&& ctx.session.features().enabled(Feature::PowershellUtf8)
|
||||
{
|
||||
|
||||
@@ -152,7 +152,8 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
|
||||
) -> Result<UnifiedExecProcess, ToolError> {
|
||||
let base_command = &req.command;
|
||||
let session_shell = ctx.session.user_shell();
|
||||
let command = maybe_wrap_shell_lc_with_snapshot(base_command, session_shell.as_ref());
|
||||
let command =
|
||||
maybe_wrap_shell_lc_with_snapshot(base_command, session_shell.as_ref(), &req.cwd);
|
||||
let command = if matches!(session_shell.shell_type, ShellType::PowerShell)
|
||||
&& ctx.session.features().enabled(Feature::PowershellUtf8)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user