Compare commits

...

4 Commits

2 changed files with 100 additions and 79 deletions

View File

@@ -1396,11 +1396,11 @@ fn parse_container_exec_arguments(
fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams {
if sess.shell_environment_policy.use_profile {
let command = sess
let wrapped_params = sess
.user_shell
.format_default_shell_invocation(params.command.clone());
if let Some(command) = command {
return ExecParams { command, ..params };
.format_default_shell_invocation(params.clone());
if let Some(wrapped_params) = wrapped_params {
return wrapped_params;
}
}
params

View File

@@ -1,9 +1,14 @@
use std::collections::HashMap;
use shlex;
use tokio::process::Command;
use crate::exec::ExecParams;
#[derive(Debug, PartialEq, Eq)]
pub struct ZshShell {
shell_path: String,
zshrc_path: String,
env: HashMap<String, String>,
}
#[derive(Debug, PartialEq, Eq)]
@@ -13,25 +18,30 @@ pub enum Shell {
}
impl Shell {
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
pub fn format_default_shell_invocation(&self, params: ExecParams) -> Option<ExecParams> {
match self {
Shell::Zsh(zsh) => {
if !std::path::Path::new(&zsh.zshrc_path).exists() {
return None;
}
let mut result = vec![zsh.shell_path.clone()];
result.push("-lc".to_string());
let mut result = vec![zsh.shell_path.clone(), "-c".to_string()];
let command = params.command;
let joined = strip_bash_lc(&command)
.or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok());
if let Some(joined) = joined {
result.push(format!("source {} && ({joined})", zsh.zshrc_path));
result.push(format!("({joined})"));
} else {
return None;
}
Some(result)
Some(ExecParams {
command: result,
env: {
let mut env = params.env.clone();
env.extend(zsh.env.clone());
env
},
..params
})
}
Shell::Unknown => None,
}
@@ -54,6 +64,7 @@ fn strip_bash_lc(command: &Vec<String>) -> Option<String> {
#[cfg(target_os = "macos")]
pub async fn default_user_shell() -> Shell {
use tokio::process::Command;
use tracing::warn;
use whoami;
let user = whoami::username();
@@ -72,9 +83,27 @@ pub async fn default_user_shell() -> Shell {
for line in stdout.lines() {
if let Some(shell_path) = line.strip_prefix("UserShell: ") {
if shell_path.ends_with("/zsh") {
let zshrc_path = format!("{home}/.zshrc");
let mut collect_env_args = vec!["-lc".to_string()];
if std::path::Path::new(&zshrc_path).exists() {
collect_env_args
.push(format!("source {zshrc_path} >/dev/null 2>&1; printenv"));
} else {
collect_env_args.push("printenv".to_string());
}
let env = match collect_env(shell_path, collect_env_args).await {
Ok(env) => env,
Err(e) => {
warn!("Failed to collect env: {e}");
HashMap::new()
}
};
return Shell::Zsh(ZshShell {
shell_path: shell_path.to_string(),
zshrc_path: format!("{home}/.zshrc"),
env,
});
}
}
@@ -86,6 +115,25 @@ pub async fn default_user_shell() -> Shell {
}
}
async fn collect_env(
command: &str,
args: Vec<String>,
) -> Result<HashMap<String, String>, std::io::Error> {
let output = Command::new(command)
.args(args)
.env_clear()
.output()
.await?;
let mut env = HashMap::new();
for line in String::from_utf8_lossy(&output.stdout).lines() {
let parts: Vec<&str> = line.splitn(2, '=').collect();
if parts.len() == 2 {
env.insert(parts[0].to_string(), parts[1].to_string());
}
}
Ok(env)
}
#[cfg(not(target_os = "macos"))]
pub async fn default_user_shell() -> Shell {
Shell::Unknown
@@ -106,27 +154,22 @@ mod tests {
.output()
.unwrap();
let home = std::env::var("HOME").unwrap();
let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
if shell_path.ends_with("/zsh") {
assert_eq!(
default_user_shell().await,
Shell::Zsh(ZshShell {
shell_path: shell_path.to_string(),
zshrc_path: format!("{home}/.zshrc",),
})
);
}
}
let shell = default_user_shell().await;
#[tokio::test]
async fn test_run_with_profile_zshrc_not_exists() {
let shell = Shell::Zsh(ZshShell {
shell_path: "/bin/zsh".to_string(),
zshrc_path: "/does/not/exist/.zshrc".to_string(),
});
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
assert_eq!(actual_cmd, None);
if let Shell::Zsh(ZshShell {
shell_path: actual_shell_path,
env,
}) = shell
{
assert_eq!(actual_shell_path, shell_path);
assert!(env.contains_key("PATH"));
assert!(env.contains_key("HOME"));
} else {
panic!("Expected Zsh shell, got {shell:?}");
}
}
}
#[expect(clippy::unwrap_used)]
@@ -136,31 +179,22 @@ mod tests {
let cases = vec![
(
vec!["myecho"],
vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["myecho"],
vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
Some("It works!\n"),
vec!["bash", "-lc", "echo $MY_VAR"],
vec![shell_path, "-c", "(echo $MY_VAR)"],
Some("123\n"),
),
(
vec!["bash", "-c", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
"-c",
"(bash -c \"echo 'single' \\\"double\\\"\")",
],
Some("single double\n"),
),
(
vec!["bash", "-lc", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source ZSHRC_PATH && (echo 'single' \"double\")",
],
vec![shell_path, "-c", "(echo 'single' \"double\")"],
Some("single double\n"),
),
];
@@ -176,46 +210,33 @@ mod tests {
use crate::exec::process_exec_tool_call;
use crate::protocol::SandboxPolicy;
// create a temp directory with a zshrc file in it
let temp_home = tempfile::tempdir().unwrap();
let zshrc_path = temp_home.path().join(".zshrc");
std::fs::write(
&zshrc_path,
r#"
set -x
function myecho {
echo 'It works!'
}
"#,
)
.unwrap();
let shell = Shell::Zsh(ZshShell {
shell_path: shell_path.to_string(),
zshrc_path: zshrc_path.to_str().unwrap().to_string(),
env: HashMap::from([("MY_VAR".to_string(), "123".to_string())]),
});
let actual_cmd = shell
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
let expected_cmd = expected_cmd
.iter()
.map(|s| {
s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
.to_string()
.format_default_shell_invocation(ExecParams {
command: input.iter().map(|s| s.to_string()).collect(),
cwd: PathBuf::from("/"),
timeout_ms: None,
env: HashMap::from([("MY_OTHER_VAR".to_string(), "456".to_string())]),
})
.collect();
.unwrap();
assert_eq!(actual_cmd, Some(expected_cmd));
let expected_cmd = expected_cmd.clone();
assert_eq!(actual_cmd.command, expected_cmd);
assert_eq!(
actual_cmd.env,
HashMap::from([
("MY_VAR".to_string(), "123".to_string()),
("MY_OTHER_VAR".to_string(), "456".to_string()),
])
);
// Actually run the command and check output/exit code
let output = process_exec_tool_call(
ExecParams {
command: actual_cmd.unwrap(),
cwd: PathBuf::from(temp_home.path()),
timeout_ms: None,
env: HashMap::from([(
"HOME".to_string(),
temp_home.path().to_str().unwrap().to_string(),
)]),
},
actual_cmd,
SandboxType::None,
Arc::new(Notify::new()),
&SandboxPolicy::DangerFullAccess,