Always fallback to real shell (#6953)

Either cmd.exe or `/bin/sh`.
This commit is contained in:
pakrym-oai
2025-11-20 10:58:46 -08:00
committed by GitHub
parent d909048a85
commit 30ca89424c
6 changed files with 234 additions and 131 deletions

View File

@@ -100,7 +100,7 @@ pub fn extract_bash_command(command: &[String]) -> Option<(&str, &str)> {
if !matches!(flag.as_str(), "-lc" | "-c") if !matches!(flag.as_str(), "-lc" | "-c")
|| !matches!( || !matches!(
detect_shell_type(&PathBuf::from(shell)), detect_shell_type(&PathBuf::from(shell)),
Some(ShellType::Zsh) | Some(ShellType::Bash) Some(ShellType::Zsh) | Some(ShellType::Bash) | Some(ShellType::Sh)
) )
{ {
return None; return None;

View File

@@ -493,7 +493,7 @@ impl Session {
// - load history metadata // - load history metadata
let rollout_fut = RolloutRecorder::new(&config, rollout_params); let rollout_fut = RolloutRecorder::new(&config, rollout_params);
let default_shell_fut = shell::default_user_shell(); let default_shell = shell::default_user_shell();
let history_meta_fut = crate::message_history::history_metadata(&config); let history_meta_fut = crate::message_history::history_metadata(&config);
let auth_statuses_fut = compute_auth_statuses( let auth_statuses_fut = compute_auth_statuses(
config.mcp_servers.iter(), config.mcp_servers.iter(),
@@ -501,12 +501,8 @@ impl Session {
); );
// Join all independent futures. // Join all independent futures.
let (rollout_recorder, default_shell, (history_log_id, history_entry_count), auth_statuses) = tokio::join!( let (rollout_recorder, (history_log_id, history_entry_count), auth_statuses) =
rollout_fut, tokio::join!(rollout_fut, history_meta_fut, auth_statuses_fut);
default_shell_fut,
history_meta_fut,
auth_statuses_fut
);
let rollout_recorder = rollout_recorder.map_err(|e| { let rollout_recorder = rollout_recorder.map_err(|e| {
error!("failed to initialize rollout recorder: {e:#}"); error!("failed to initialize rollout recorder: {e:#}");
@@ -1057,7 +1053,7 @@ impl Session {
Some(turn_context.cwd.clone()), Some(turn_context.cwd.clone()),
Some(turn_context.approval_policy), Some(turn_context.approval_policy),
Some(turn_context.sandbox_policy.clone()), Some(turn_context.sandbox_policy.clone()),
Some(self.user_shell().clone()), self.user_shell().clone(),
))); )));
items items
} }
@@ -2390,6 +2386,7 @@ mod tests {
use crate::config::ConfigOverrides; use crate::config::ConfigOverrides;
use crate::config::ConfigToml; use crate::config::ConfigToml;
use crate::exec::ExecToolCallOutput; use crate::exec::ExecToolCallOutput;
use crate::shell::default_user_shell;
use crate::tools::format_exec_output_str; use crate::tools::format_exec_output_str;
use crate::protocol::CompactedItem; use crate::protocol::CompactedItem;
@@ -2629,7 +2626,7 @@ mod tests {
unified_exec_manager: UnifiedExecSessionManager::default(), unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::new(None), notifier: UserNotifier::new(None),
rollout: Mutex::new(None), rollout: Mutex::new(None),
user_shell: shell::Shell::Unknown, user_shell: default_user_shell(),
show_raw_agent_reasoning: config.show_raw_agent_reasoning, show_raw_agent_reasoning: config.show_raw_agent_reasoning,
auth_manager: Arc::clone(&auth_manager), auth_manager: Arc::clone(&auth_manager),
otel_event_manager: otel_event_manager.clone(), otel_event_manager: otel_event_manager.clone(),
@@ -2707,7 +2704,7 @@ mod tests {
unified_exec_manager: UnifiedExecSessionManager::default(), unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::new(None), notifier: UserNotifier::new(None),
rollout: Mutex::new(None), rollout: Mutex::new(None),
user_shell: shell::Shell::Unknown, user_shell: default_user_shell(),
show_raw_agent_reasoning: config.show_raw_agent_reasoning, show_raw_agent_reasoning: config.show_raw_agent_reasoning,
auth_manager: Arc::clone(&auth_manager), auth_manager: Arc::clone(&auth_manager),
otel_event_manager: otel_event_manager.clone(), otel_event_manager: otel_event_manager.clone(),

View File

@@ -6,6 +6,7 @@ use crate::codex::TurnContext;
use crate::protocol::AskForApproval; use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy; use crate::protocol::SandboxPolicy;
use crate::shell::Shell; use crate::shell::Shell;
use crate::shell::default_user_shell;
use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::SandboxMode;
use codex_protocol::models::ContentItem; use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem; use codex_protocol::models::ResponseItem;
@@ -28,7 +29,7 @@ pub(crate) struct EnvironmentContext {
pub sandbox_mode: Option<SandboxMode>, pub sandbox_mode: Option<SandboxMode>,
pub network_access: Option<NetworkAccess>, pub network_access: Option<NetworkAccess>,
pub writable_roots: Option<Vec<PathBuf>>, pub writable_roots: Option<Vec<PathBuf>>,
pub shell: Option<Shell>, pub shell: Shell,
} }
impl EnvironmentContext { impl EnvironmentContext {
@@ -36,7 +37,7 @@ impl EnvironmentContext {
cwd: Option<PathBuf>, cwd: Option<PathBuf>,
approval_policy: Option<AskForApproval>, approval_policy: Option<AskForApproval>,
sandbox_policy: Option<SandboxPolicy>, sandbox_policy: Option<SandboxPolicy>,
shell: Option<Shell>, shell: Shell,
) -> Self { ) -> Self {
Self { Self {
cwd, cwd,
@@ -110,7 +111,7 @@ impl EnvironmentContext {
} else { } else {
None None
}; };
EnvironmentContext::new(cwd, approval_policy, sandbox_policy, None) EnvironmentContext::new(cwd, approval_policy, sandbox_policy, default_user_shell())
} }
} }
@@ -121,7 +122,7 @@ impl From<&TurnContext> for EnvironmentContext {
Some(turn_context.approval_policy), Some(turn_context.approval_policy),
Some(turn_context.sandbox_policy.clone()), Some(turn_context.sandbox_policy.clone()),
// Shell is not configurable from turn to turn // Shell is not configurable from turn to turn
None, default_user_shell(),
) )
} }
} }
@@ -169,11 +170,9 @@ impl EnvironmentContext {
} }
lines.push(" </writable_roots>".to_string()); lines.push(" </writable_roots>".to_string());
} }
if let Some(shell) = self.shell
&& let Some(shell_name) = shell.name() let shell_name = self.shell.name();
{ lines.push(format!(" <shell>{shell_name}</shell>"));
lines.push(format!(" <shell>{shell_name}</shell>"));
}
lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string()); lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
lines.join("\n") lines.join("\n")
} }
@@ -193,12 +192,18 @@ impl From<EnvironmentContext> for ResponseItem {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::shell::BashShell; use crate::shell::ShellType;
use crate::shell::ZshShell;
use super::*; use super::*;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
fn fake_shell() -> Shell {
Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
}
}
fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy { fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
SandboxPolicy::WorkspaceWrite { SandboxPolicy::WorkspaceWrite {
writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(), writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
@@ -214,7 +219,7 @@ mod tests {
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp"], false)), Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
None, fake_shell(),
); );
let expected = r#"<environment_context> let expected = r#"<environment_context>
@@ -226,6 +231,7 @@ mod tests {
<root>/repo</root> <root>/repo</root>
<root>/tmp</root> <root>/tmp</root>
</writable_roots> </writable_roots>
<shell>bash</shell>
</environment_context>"#; </environment_context>"#;
assert_eq!(context.serialize_to_xml(), expected); assert_eq!(context.serialize_to_xml(), expected);
@@ -237,13 +243,14 @@ mod tests {
None, None,
Some(AskForApproval::Never), Some(AskForApproval::Never),
Some(SandboxPolicy::ReadOnly), Some(SandboxPolicy::ReadOnly),
None, fake_shell(),
); );
let expected = r#"<environment_context> let expected = r#"<environment_context>
<approval_policy>never</approval_policy> <approval_policy>never</approval_policy>
<sandbox_mode>read-only</sandbox_mode> <sandbox_mode>read-only</sandbox_mode>
<network_access>restricted</network_access> <network_access>restricted</network_access>
<shell>bash</shell>
</environment_context>"#; </environment_context>"#;
assert_eq!(context.serialize_to_xml(), expected); assert_eq!(context.serialize_to_xml(), expected);
@@ -255,13 +262,14 @@ mod tests {
None, None,
Some(AskForApproval::OnFailure), Some(AskForApproval::OnFailure),
Some(SandboxPolicy::DangerFullAccess), Some(SandboxPolicy::DangerFullAccess),
None, fake_shell(),
); );
let expected = r#"<environment_context> let expected = r#"<environment_context>
<approval_policy>on-failure</approval_policy> <approval_policy>on-failure</approval_policy>
<sandbox_mode>danger-full-access</sandbox_mode> <sandbox_mode>danger-full-access</sandbox_mode>
<network_access>enabled</network_access> <network_access>enabled</network_access>
<shell>bash</shell>
</environment_context>"#; </environment_context>"#;
assert_eq!(context.serialize_to_xml(), expected); assert_eq!(context.serialize_to_xml(), expected);
@@ -274,13 +282,13 @@ mod tests {
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)), Some(workspace_write_policy(vec!["/repo"], false)),
None, fake_shell(),
); );
let context2 = EnvironmentContext::new( let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::Never), Some(AskForApproval::Never),
Some(workspace_write_policy(vec!["/repo"], true)), Some(workspace_write_policy(vec!["/repo"], true)),
None, fake_shell(),
); );
assert!(!context1.equals_except_shell(&context2)); assert!(!context1.equals_except_shell(&context2));
} }
@@ -291,13 +299,13 @@ mod tests {
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(SandboxPolicy::new_read_only_policy()), Some(SandboxPolicy::new_read_only_policy()),
None, fake_shell(),
); );
let context2 = EnvironmentContext::new( let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(SandboxPolicy::new_workspace_write_policy()), Some(SandboxPolicy::new_workspace_write_policy()),
None, fake_shell(),
); );
assert!(!context1.equals_except_shell(&context2)); assert!(!context1.equals_except_shell(&context2));
@@ -309,13 +317,13 @@ mod tests {
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)), Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
None, fake_shell(),
); );
let context2 = EnvironmentContext::new( let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp"], true)), Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
None, fake_shell(),
); );
assert!(!context1.equals_except_shell(&context2)); assert!(!context1.equals_except_shell(&context2));
@@ -327,17 +335,19 @@ mod tests {
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)), Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Bash(BashShell { Shell {
shell_type: ShellType::Bash,
shell_path: "/bin/bash".into(), shell_path: "/bin/bash".into(),
})), },
); );
let context2 = EnvironmentContext::new( let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")), Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest), Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)), Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Zsh(ZshShell { Shell {
shell_type: ShellType::Zsh,
shell_path: "/bin/zsh".into(), shell_path: "/bin/zsh".into(),
})), },
); );
assert!(context1.equals_except_shell(&context2)); assert!(context1.equals_except_shell(&context2));

View File

@@ -7,61 +7,41 @@ pub enum ShellType {
Zsh, Zsh,
Bash, Bash,
PowerShell, PowerShell,
Sh,
Cmd,
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell { pub struct Shell {
pub(crate) shell_type: ShellType,
pub(crate) shell_path: PathBuf, pub(crate) shell_path: PathBuf,
} }
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
pub(crate) shell_path: PathBuf,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
pub(crate) shell_path: PathBuf, // Executable name or path, e.g. "pwsh" or "powershell.exe".
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
Zsh(ZshShell),
Bash(BashShell),
PowerShell(PowerShellConfig),
Unknown,
}
impl Shell { impl Shell {
pub fn name(&self) -> Option<String> { pub fn name(&self) -> &'static str {
match self { match self.shell_type {
Shell::Zsh(ZshShell { shell_path, .. }) | Shell::Bash(BashShell { shell_path, .. }) => { ShellType::Zsh => "zsh",
std::path::Path::new(shell_path) ShellType::Bash => "bash",
.file_name() ShellType::PowerShell => "powershell",
.map(|s| s.to_string_lossy().to_string()) ShellType::Sh => "sh",
} ShellType::Cmd => "cmd",
Shell::PowerShell(ps) => ps
.shell_path
.file_stem()
.map(|s| s.to_string_lossy().to_string()),
Shell::Unknown => None,
} }
} }
/// Takes a string of shell and returns the full list of command args to /// Takes a string of shell and returns the full list of command args to
/// use with `exec()` to run the shell command. /// use with `exec()` to run the shell command.
pub fn derive_exec_args(&self, command: &str, use_login_shell: bool) -> Vec<String> { pub fn derive_exec_args(&self, command: &str, use_login_shell: bool) -> Vec<String> {
match self { match self.shell_type {
Shell::Zsh(ZshShell { shell_path, .. }) | Shell::Bash(BashShell { shell_path, .. }) => { ShellType::Zsh | ShellType::Bash | ShellType::Sh => {
let arg = if use_login_shell { "-lc" } else { "-c" }; let arg = if use_login_shell { "-lc" } else { "-c" };
vec![ vec![
shell_path.to_string_lossy().to_string(), self.shell_path.to_string_lossy().to_string(),
arg.to_string(), arg.to_string(),
command.to_string(), command.to_string(),
] ]
} }
Shell::PowerShell(ps) => { ShellType::PowerShell => {
let mut args = vec![ps.shell_path.to_string_lossy().to_string()]; let mut args = vec![self.shell_path.to_string_lossy().to_string()];
if !use_login_shell { if !use_login_shell {
args.push("-NoProfile".to_string()); args.push("-NoProfile".to_string());
} }
@@ -70,7 +50,12 @@ impl Shell {
args.push(command.to_string()); args.push(command.to_string());
args args
} }
Shell::Unknown => shlex::split(command).unwrap_or_else(|| vec![command.to_string()]), ShellType::Cmd => {
let mut args = vec![self.shell_path.to_string_lossy().to_string()];
args.push("/c".to_string());
args.push(command.to_string());
args
}
} }
} }
} }
@@ -143,19 +128,34 @@ fn get_shell_path(
None None
} }
fn get_zsh_shell(path: Option<&PathBuf>) -> Option<ZshShell> { fn get_zsh_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Zsh, path, "zsh", vec!["/bin/zsh"]); let shell_path = get_shell_path(ShellType::Zsh, path, "zsh", vec!["/bin/zsh"]);
shell_path.map(|shell_path| ZshShell { shell_path }) shell_path.map(|shell_path| Shell {
shell_type: ShellType::Zsh,
shell_path,
})
} }
fn get_bash_shell(path: Option<&PathBuf>) -> Option<BashShell> { fn get_bash_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Bash, path, "bash", vec!["/bin/bash"]); let shell_path = get_shell_path(ShellType::Bash, path, "bash", vec!["/bin/bash"]);
shell_path.map(|shell_path| BashShell { shell_path }) shell_path.map(|shell_path| Shell {
shell_type: ShellType::Bash,
shell_path,
})
} }
fn get_powershell_shell(path: Option<&PathBuf>) -> Option<PowerShellConfig> { fn get_sh_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Sh, path, "sh", vec!["/bin/sh"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Sh,
shell_path,
})
}
fn get_powershell_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path( let shell_path = get_shell_path(
ShellType::PowerShell, ShellType::PowerShell,
path, path,
@@ -164,26 +164,56 @@ fn get_powershell_shell(path: Option<&PathBuf>) -> Option<PowerShellConfig> {
) )
.or_else(|| get_shell_path(ShellType::PowerShell, path, "powershell", vec![])); .or_else(|| get_shell_path(ShellType::PowerShell, path, "powershell", vec![]));
shell_path.map(|shell_path| PowerShellConfig { shell_path }) shell_path.map(|shell_path| Shell {
shell_type: ShellType::PowerShell,
shell_path,
})
}
fn get_cmd_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Cmd, path, "cmd", vec![]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Cmd,
shell_path,
})
}
fn ultimate_fallback_shell() -> Shell {
if cfg!(windows) {
Shell {
shell_type: ShellType::Cmd,
shell_path: PathBuf::from("cmd.exe"),
}
} else {
Shell {
shell_type: ShellType::Sh,
shell_path: PathBuf::from("/bin/sh"),
}
}
} }
pub fn get_shell_by_model_provided_path(shell_path: &PathBuf) -> Shell { pub fn get_shell_by_model_provided_path(shell_path: &PathBuf) -> Shell {
detect_shell_type(shell_path) detect_shell_type(shell_path)
.and_then(|shell_type| get_shell(shell_type, Some(shell_path))) .and_then(|shell_type| get_shell(shell_type, Some(shell_path)))
.unwrap_or(Shell::Unknown) .unwrap_or(ultimate_fallback_shell())
} }
pub fn get_shell(shell_type: ShellType, path: Option<&PathBuf>) -> Option<Shell> { pub fn get_shell(shell_type: ShellType, path: Option<&PathBuf>) -> Option<Shell> {
match shell_type { match shell_type {
ShellType::Zsh => get_zsh_shell(path).map(Shell::Zsh), ShellType::Zsh => get_zsh_shell(path),
ShellType::Bash => get_bash_shell(path).map(Shell::Bash), ShellType::Bash => get_bash_shell(path),
ShellType::PowerShell => get_powershell_shell(path).map(Shell::PowerShell), ShellType::PowerShell => get_powershell_shell(path),
ShellType::Sh => get_sh_shell(path),
ShellType::Cmd => get_cmd_shell(path),
} }
} }
pub fn detect_shell_type(shell_path: &PathBuf) -> Option<ShellType> { pub fn detect_shell_type(shell_path: &PathBuf) -> Option<ShellType> {
match shell_path.as_os_str().to_str() { match shell_path.as_os_str().to_str() {
Some("zsh") => Some(ShellType::Zsh), Some("zsh") => Some(ShellType::Zsh),
Some("sh") => Some(ShellType::Sh),
Some("cmd") => Some(ShellType::Cmd),
Some("bash") => Some(ShellType::Bash), Some("bash") => Some(ShellType::Bash),
Some("pwsh") => Some(ShellType::PowerShell), Some("pwsh") => Some(ShellType::PowerShell),
Some("powershell") => Some(ShellType::PowerShell), Some("powershell") => Some(ShellType::PowerShell),
@@ -200,11 +230,15 @@ pub fn detect_shell_type(shell_path: &PathBuf) -> Option<ShellType> {
} }
} }
pub async fn default_user_shell() -> Shell { pub fn default_user_shell() -> Shell {
default_user_shell_from_path(get_user_shell_path())
}
fn default_user_shell_from_path(user_shell_path: Option<PathBuf>) -> Shell {
if cfg!(windows) { if cfg!(windows) {
get_shell(ShellType::PowerShell, None).unwrap_or(Shell::Unknown) get_shell(ShellType::PowerShell, None).unwrap_or(ultimate_fallback_shell())
} else { } else {
let user_default_shell = get_user_shell_path() let user_default_shell = user_shell_path
.and_then(|shell| detect_shell_type(&shell)) .and_then(|shell| detect_shell_type(&shell))
.and_then(|shell_type| get_shell(shell_type, None)); .and_then(|shell_type| get_shell(shell_type, None));
@@ -218,7 +252,7 @@ pub async fn default_user_shell() -> Shell {
.or_else(|| get_shell(ShellType::Zsh, None)) .or_else(|| get_shell(ShellType::Zsh, None))
}; };
shell_with_fallback.unwrap_or(Shell::Unknown) shell_with_fallback.unwrap_or(ultimate_fallback_shell())
} }
} }
@@ -274,6 +308,19 @@ mod detect_shell_type_tests {
detect_shell_type(&PathBuf::from("/usr/local/bin/pwsh")), detect_shell_type(&PathBuf::from("/usr/local/bin/pwsh")),
Some(ShellType::PowerShell) Some(ShellType::PowerShell)
); );
assert_eq!(
detect_shell_type(&PathBuf::from("/bin/sh")),
Some(ShellType::Sh)
);
assert_eq!(detect_shell_type(&PathBuf::from("sh")), Some(ShellType::Sh));
assert_eq!(
detect_shell_type(&PathBuf::from("cmd")),
Some(ShellType::Cmd)
);
assert_eq!(
detect_shell_type(&PathBuf::from("cmd.exe")),
Some(ShellType::Cmd)
);
} }
} }
@@ -289,10 +336,17 @@ mod tests {
fn detects_zsh() { fn detects_zsh() {
let zsh_shell = get_shell(ShellType::Zsh, None).unwrap(); let zsh_shell = get_shell(ShellType::Zsh, None).unwrap();
let ZshShell { shell_path } = match zsh_shell { let shell_path = zsh_shell.shell_path;
Shell::Zsh(zsh_shell) => zsh_shell,
_ => panic!("expected zsh shell"), assert_eq!(shell_path, PathBuf::from("/bin/zsh"));
}; }
#[test]
#[cfg(target_os = "macos")]
fn fish_fallback_to_zsh() {
let zsh_shell = default_user_shell_from_path(Some(PathBuf::from("/bin/fish")));
let shell_path = zsh_shell.shell_path;
assert_eq!(shell_path, PathBuf::from("/bin/zsh")); assert_eq!(shell_path, PathBuf::from("/bin/zsh"));
} }
@@ -300,10 +354,7 @@ mod tests {
#[test] #[test]
fn detects_bash() { fn detects_bash() {
let bash_shell = get_shell(ShellType::Bash, None).unwrap(); let bash_shell = get_shell(ShellType::Bash, None).unwrap();
let BashShell { shell_path } = match bash_shell { let shell_path = bash_shell.shell_path;
Shell::Bash(bash_shell) => bash_shell,
_ => panic!("expected bash shell"),
};
assert!( assert!(
shell_path == PathBuf::from("/bin/bash") shell_path == PathBuf::from("/bin/bash")
@@ -312,6 +363,50 @@ mod tests {
); );
} }
#[test]
fn detects_sh() {
let sh_shell = get_shell(ShellType::Sh, None).unwrap();
let shell_path = sh_shell.shell_path;
assert!(
shell_path == PathBuf::from("/bin/sh") || shell_path == PathBuf::from("/usr/bin/sh"),
"shell path: {shell_path:?}",
);
}
#[test]
fn can_run_on_shell_test() {
let cmd = "echo \"Works\"";
if cfg!(windows) {
assert!(shell_works(
get_shell(ShellType::PowerShell, None),
"Out-String 'Works'",
true,
));
assert!(shell_works(get_shell(ShellType::Cmd, None), cmd, true,));
assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true));
} else {
assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true));
assert!(shell_works(get_shell(ShellType::Zsh, None), cmd, false));
assert!(shell_works(get_shell(ShellType::Bash, None), cmd, true));
assert!(shell_works(get_shell(ShellType::Sh, None), cmd, true));
}
}
fn shell_works(shell: Option<Shell>, command: &str, required: bool) -> bool {
if let Some(shell) = shell {
let args = shell.derive_exec_args(command, false);
let output = Command::new(args[0].clone())
.args(&args[1..])
.output()
.unwrap();
assert!(output.status.success());
assert!(String::from_utf8_lossy(&output.stdout).contains("Works"));
true
} else {
!required
}
}
#[tokio::test] #[tokio::test]
async fn test_current_shell_detects_zsh() { async fn test_current_shell_detects_zsh() {
let shell = Command::new("sh") let shell = Command::new("sh")
@@ -323,10 +418,11 @@ mod tests {
let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string(); let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
if shell_path.ends_with("/zsh") { if shell_path.ends_with("/zsh") {
assert_eq!( assert_eq!(
default_user_shell().await, default_user_shell(),
Shell::Zsh(ZshShell { Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from(shell_path), shell_path: PathBuf::from(shell_path),
}) }
); );
} }
} }
@@ -337,11 +433,8 @@ mod tests {
return; return;
} }
let powershell_shell = default_user_shell().await; let powershell_shell = default_user_shell();
let PowerShellConfig { shell_path } = match powershell_shell { let shell_path = powershell_shell.shell_path;
Shell::PowerShell(powershell_shell) => powershell_shell,
_ => panic!("expected powershell shell"),
};
assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe"));
} }
@@ -353,10 +446,7 @@ mod tests {
} }
let powershell_shell = get_shell(ShellType::PowerShell, None).unwrap(); let powershell_shell = get_shell(ShellType::PowerShell, None).unwrap();
let PowerShellConfig { shell_path } = match powershell_shell { let shell_path = powershell_shell.shell_path;
Shell::PowerShell(powershell_shell) => powershell_shell,
_ => panic!("expected powershell shell"),
};
assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe"));
} }

View File

@@ -338,29 +338,30 @@ mod tests {
use std::path::PathBuf; use std::path::PathBuf;
use crate::is_safe_command::is_known_safe_command; use crate::is_safe_command::is_known_safe_command;
use crate::shell::BashShell;
use crate::shell::PowerShellConfig;
use crate::shell::Shell; use crate::shell::Shell;
use crate::shell::ZshShell; use crate::shell::ShellType;
/// The logic for is_known_safe_command() has heuristics for known shells, /// The logic for is_known_safe_command() has heuristics for known shells,
/// so we must ensure the commands generated by [ShellCommandHandler] can be /// so we must ensure the commands generated by [ShellCommandHandler] can be
/// recognized as safe if the `command` is safe. /// recognized as safe if the `command` is safe.
#[test] #[test]
fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_command() { fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_command() {
let bash_shell = Shell::Bash(BashShell { let bash_shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"), shell_path: PathBuf::from("/bin/bash"),
}); };
assert_safe(&bash_shell, "ls -la"); assert_safe(&bash_shell, "ls -la");
let zsh_shell = Shell::Zsh(ZshShell { let zsh_shell = Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from("/bin/zsh"), shell_path: PathBuf::from("/bin/zsh"),
}); };
assert_safe(&zsh_shell, "ls -la"); assert_safe(&zsh_shell, "ls -la");
let powershell = Shell::PowerShell(PowerShellConfig { let powershell = Shell {
shell_type: ShellType::PowerShell,
shell_path: PathBuf::from("pwsh.exe"), shell_path: PathBuf::from("pwsh.exe"),
}); };
assert_safe(&powershell, "ls -Name"); assert_safe(&powershell, "ls -Name");
} }

View File

@@ -30,18 +30,15 @@ fn text_user_input(text: String) -> serde_json::Value {
} }
fn default_env_context_str(cwd: &str, shell: &Shell) -> String { fn default_env_context_str(cwd: &str, shell: &Shell) -> String {
let shell_name = shell.name();
format!( format!(
r#"<environment_context> r#"<environment_context>
<cwd>{}</cwd> <cwd>{cwd}</cwd>
<approval_policy>on-request</approval_policy> <approval_policy>on-request</approval_policy>
<sandbox_mode>read-only</sandbox_mode> <sandbox_mode>read-only</sandbox_mode>
<network_access>restricted</network_access> <network_access>restricted</network_access>
{}</environment_context>"#, <shell>{shell_name}</shell>
cwd, </environment_context>"#
match shell.name() {
Some(name) => format!(" <shell>{name}</shell>\n"),
None => String::new(),
}
) )
} }
@@ -227,7 +224,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
.await?; .await?;
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let shell = default_user_shell().await; let shell = default_user_shell();
let cwd_str = config.cwd.to_string_lossy(); let cwd_str = config.cwd.to_string_lossy();
let expected_env_text = default_env_context_str(&cwd_str, &shell); let expected_env_text = default_env_context_str(&cwd_str, &shell);
let expected_ui_text = format!( let expected_ui_text = format!(
@@ -345,6 +342,7 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() -> an
// After overriding the turn context, the environment context should be emitted again // After overriding the turn context, the environment context should be emitted again
// reflecting the new approval policy and sandbox settings. Omit cwd because it did // reflecting the new approval policy and sandbox settings. Omit cwd because it did
// not change. // not change.
let shell = default_user_shell();
let expected_env_text_2 = format!( let expected_env_text_2 = format!(
r#"<environment_context> r#"<environment_context>
<approval_policy>never</approval_policy> <approval_policy>never</approval_policy>
@@ -353,8 +351,10 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() -> an
<writable_roots> <writable_roots>
<root>{}</root> <root>{}</root>
</writable_roots> </writable_roots>
<shell>{}</shell>
</environment_context>"#, </environment_context>"#,
writable.path().to_string_lossy(), writable.path().display(),
shell.name()
); );
let expected_env_msg_2 = serde_json::json!({ let expected_env_msg_2 = serde_json::json!({
"type": "message", "type": "message",
@@ -522,6 +522,8 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res
"role": "user", "role": "user",
"content": [ { "type": "input_text", "text": "hello 2" } ] "content": [ { "type": "input_text", "text": "hello 2" } ]
}); });
let shell = default_user_shell();
let expected_env_text_2 = format!( let expected_env_text_2 = format!(
r#"<environment_context> r#"<environment_context>
<cwd>{}</cwd> <cwd>{}</cwd>
@@ -531,9 +533,11 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() -> anyhow::Res
<writable_roots> <writable_roots>
<root>{}</root> <root>{}</root>
</writable_roots> </writable_roots>
<shell>{}</shell>
</environment_context>"#, </environment_context>"#,
new_cwd.path().to_string_lossy(), new_cwd.path().display(),
writable.path().to_string_lossy(), writable.path().display(),
shell.name(),
); );
let expected_env_msg_2 = serde_json::json!({ let expected_env_msg_2 = serde_json::json!({
"type": "message", "type": "message",
@@ -610,7 +614,7 @@ async fn send_user_turn_with_no_changes_does_not_send_environment_context() -> a
let body1 = req1.single_request().body_json(); let body1 = req1.single_request().body_json();
let body2 = req2.single_request().body_json(); let body2 = req2.single_request().body_json();
let shell = default_user_shell().await; let shell = default_user_shell();
let default_cwd_lossy = default_cwd.to_string_lossy(); let default_cwd_lossy = default_cwd.to_string_lossy();
let expected_ui_text = format!( let expected_ui_text = format!(
"# AGENTS.md instructions for {default_cwd_lossy}\n\n<INSTRUCTIONS>\nbe consistent and helpful\n</INSTRUCTIONS>" "# AGENTS.md instructions for {default_cwd_lossy}\n\n<INSTRUCTIONS>\nbe consistent and helpful\n</INSTRUCTIONS>"
@@ -697,7 +701,7 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu
let body1 = req1.single_request().body_json(); let body1 = req1.single_request().body_json();
let body2 = req2.single_request().body_json(); let body2 = req2.single_request().body_json();
let shell = default_user_shell().await; let shell = default_user_shell();
let expected_ui_text = format!( let expected_ui_text = format!(
"# AGENTS.md instructions for {}\n\n<INSTRUCTIONS>\nbe consistent and helpful\n</INSTRUCTIONS>", "# AGENTS.md instructions for {}\n\n<INSTRUCTIONS>\nbe consistent and helpful\n</INSTRUCTIONS>",
default_cwd.to_string_lossy() default_cwd.to_string_lossy()
@@ -717,14 +721,15 @@ async fn send_user_turn_with_changes_sends_environment_context() -> anyhow::Resu
]); ]);
assert_eq!(body1["input"], expected_input_1); assert_eq!(body1["input"], expected_input_1);
let expected_env_msg_2 = text_user_input( let shell_name = shell.name();
let expected_env_msg_2 = text_user_input(format!(
r#"<environment_context> r#"<environment_context>
<approval_policy>never</approval_policy> <approval_policy>never</approval_policy>
<sandbox_mode>danger-full-access</sandbox_mode> <sandbox_mode>danger-full-access</sandbox_mode>
<network_access>enabled</network_access> <network_access>enabled</network_access>
<shell>{shell_name}</shell>
</environment_context>"# </environment_context>"#
.to_string(), ));
);
let expected_user_message_2 = text_user_input("hello 2".to_string()); let expected_user_message_2 = text_user_input("hello 2".to_string());
let expected_input_2 = serde_json::Value::Array(vec![ let expected_input_2 = serde_json::Value::Array(vec![
expected_ui_msg, expected_ui_msg,