Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Bolin
62a42c31a7 Support zsh shell tool with shell-escalation 2026-02-23 09:52:41 -08:00
Michael Bolin
86bfa68e42 Use Arc-based ToolCtx in tool runtimes 2026-02-23 09:52:41 -08:00
32 changed files with 1647 additions and 701 deletions

26
codex-rs/Cargo.lock generated
View File

@@ -1408,6 +1408,7 @@ dependencies = [
"anyhow",
"codex-apply-patch",
"codex-linux-sandbox",
"codex-shell-escalation",
"codex-utils-home-dir",
"dotenvy",
"tempfile",
@@ -1653,6 +1654,7 @@ dependencies = [
"codex-rmcp-client",
"codex-secrets",
"codex-shell-command",
"codex-shell-escalation",
"codex-skills",
"codex-state",
"codex-utils-absolute-path",
@@ -2199,6 +2201,7 @@ dependencies = [
"which",
]
[[package]]
name = "codex-shell-escalation"
version = "0.0.0"
@@ -2220,6 +2223,29 @@ dependencies = [
"anyhow",
"async-trait",
"codex-execpolicy",
"codex-protocol",
"libc",
"path-absolutize",
"pretty_assertions",
"serde",
"serde_json",
"socket2 0.6.2",
"tempfile",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "codex-skills"
version = "0.0.0"

View File

@@ -24,11 +24,6 @@ struct AppServerArgs {
fn main() -> anyhow::Result<()> {
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
// Run wrapper mode only after arg0 dispatch so `codex-linux-sandbox`
// invocations don't get misclassified as zsh exec-wrapper calls.
if codex_core::maybe_run_zsh_exec_wrapper_mode()? {
return Ok(());
}
let args = AppServerArgs::parse();
let managed_config_path = managed_config_path_from_debug_env();
let loader_overrides = LoaderOverrides {

View File

@@ -19,3 +19,6 @@ codex-utils-home-dir = { workspace = true }
dotenvy = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread"] }
[target.'cfg(unix)'.dependencies]
codex-shell-escalation = { workspace = true }

View File

@@ -12,6 +12,8 @@ use tempfile::TempDir;
const LINUX_SANDBOX_ARG0: &str = "codex-linux-sandbox";
const APPLY_PATCH_ARG0: &str = "apply_patch";
const MISSPELLED_APPLY_PATCH_ARG0: &str = "applypatch";
#[cfg(unix)]
const EXECVE_WRAPPER_ARG0: &str = "codex-execve-wrapper";
const LOCK_FILENAME: &str = ".lock";
const TOKIO_WORKER_STACK_SIZE_BYTES: usize = 16 * 1024 * 1024;
@@ -39,6 +41,32 @@ pub fn arg0_dispatch() -> Option<Arg0PathEntryGuard> {
.and_then(|s| s.to_str())
.unwrap_or("");
#[cfg(unix)]
if exe_name == EXECVE_WRAPPER_ARG0 {
let mut args = std::env::args();
let _ = args.next();
let file = match args.next() {
Some(file) => file,
None => std::process::exit(1),
};
let argv = args.collect::<Vec<_>>();
let runtime = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(runtime) => runtime,
Err(_) => std::process::exit(1),
};
let exit_code = runtime.block_on(codex_shell_escalation::unix::escalate_client::run(
file, argv,
));
match exit_code {
Ok(exit_code) => std::process::exit(exit_code),
Err(_) => std::process::exit(1),
}
}
if exe_name == LINUX_SANDBOX_ARG0 {
// Safety: [`run_main`] never returns.
codex_linux_sandbox::run_main();
@@ -227,6 +255,8 @@ pub fn prepend_path_entry_for_codex_aliases() -> std::io::Result<Arg0PathEntryGu
MISSPELLED_APPLY_PATCH_ARG0,
#[cfg(target_os = "linux")]
LINUX_SANDBOX_ARG0,
#[cfg(unix)]
EXECVE_WRAPPER_ARG0,
] {
let exe = std::env::current_exe()?;

View File

@@ -544,11 +544,6 @@ fn stage_str(stage: codex_core::features::Stage) -> &'static str {
fn main() -> anyhow::Result<()> {
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
// Run wrapper mode only after arg0 dispatch so `codex-linux-sandbox`
// invocations don't get misclassified as zsh exec-wrapper calls.
if codex_core::maybe_run_zsh_exec_wrapper_mode()? {
return Ok(());
}
cli_main(codex_linux_sandbox_exe).await?;
Ok(())
})

View File

@@ -138,6 +138,9 @@ windows-sys = { version = "0.52", features = [
[target.'cfg(any(target_os = "freebsd", target_os = "openbsd"))'.dependencies]
keyring = { workspace = true, features = ["sync-secret-service"] }
[target.'cfg(unix)'.dependencies]
codex-shell-escalation = { workspace = true }
[dev-dependencies]
assert_cmd = { workspace = true }
assert_matches = { workspace = true }

View File

@@ -249,7 +249,6 @@ use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecProcessManager;
use crate::util::backoff;
use crate::windows_sandbox::WindowsSandboxLevelExt;
use crate::zsh_exec_bridge::ZshExecBridge;
use codex_async_utils::OrCancelExt;
use codex_otel::OtelManager;
use codex_otel::TelemetryAuthMode;
@@ -1208,7 +1207,8 @@ impl Session {
"zsh fork feature enabled, but `zsh_path` is not configured; set `zsh_path` in config.toml"
)
})?;
shell::get_shell(shell::ShellType::Zsh, Some(zsh_path)).ok_or_else(|| {
let zsh_path = zsh_path.to_path_buf();
shell::get_shell(shell::ShellType::Zsh, Some(&zsh_path)).ok_or_else(|| {
anyhow::anyhow!(
"zsh fork feature enabled, but zsh_path `{}` is not usable; set `zsh_path` to a valid zsh executable",
zsh_path.display()
@@ -1287,12 +1287,6 @@ impl Session {
(None, None)
};
let zsh_exec_bridge =
ZshExecBridge::new(config.zsh_path.clone(), config.codex_home.clone());
zsh_exec_bridge
.initialize_for_session(&conversation_id.to_string())
.await;
let services = SessionServices {
// Initialize the MCP connection manager with an uninitialized
// instance. It will be replaced with one created via
@@ -1308,7 +1302,7 @@ impl Session {
unified_exec_manager: UnifiedExecProcessManager::new(
config.background_terminal_max_timeout,
),
zsh_exec_bridge,
shell_zsh_path: config.zsh_path.clone(),
analytics_events_client: AnalyticsEventsClient::new(
Arc::clone(&config),
Arc::clone(&auth_manager),
@@ -4227,7 +4221,6 @@ mod handlers {
.unified_exec_manager
.terminate_all_processes()
.await;
sess.services.zsh_exec_bridge.shutdown().await;
info!("Shutting down Codex instance");
let history = sess.clone_history().await;
let turn_count = history
@@ -7895,7 +7888,7 @@ mod tests {
unified_exec_manager: UnifiedExecProcessManager::new(
config.background_terminal_max_timeout,
),
zsh_exec_bridge: ZshExecBridge::default(),
shell_zsh_path: None,
analytics_events_client: AnalyticsEventsClient::new(
Arc::clone(&config),
Arc::clone(&auth_manager),
@@ -8048,7 +8041,7 @@ mod tests {
unified_exec_manager: UnifiedExecProcessManager::new(
config.background_terminal_max_timeout,
),
zsh_exec_bridge: ZshExecBridge::default(),
shell_zsh_path: None,
analytics_events_client: AnalyticsEventsClient::new(
Arc::clone(&config),
Arc::clone(&auth_manager),

View File

@@ -372,7 +372,7 @@ pub struct Config {
pub js_repl_node_module_dirs: Vec<PathBuf>,
/// Optional absolute path to patched zsh used by zsh-exec-bridge-backed shell execution.
pub zsh_path: Option<PathBuf>,
pub zsh_path: Option<AbsolutePathBuf>,
/// Value to use for `reasoning.effort` when making a request using the
/// Responses API.
@@ -1484,7 +1484,7 @@ pub struct ConfigOverrides {
pub codex_linux_sandbox_exe: Option<PathBuf>,
pub js_repl_node_path: Option<PathBuf>,
pub js_repl_node_module_dirs: Option<Vec<PathBuf>>,
pub zsh_path: Option<PathBuf>,
pub zsh_path: Option<AbsolutePathBuf>,
pub base_instructions: Option<String>,
pub developer_instructions: Option<String>,
pub personality: Option<Personality>,
@@ -1905,8 +1905,8 @@ impl Config {
})
.unwrap_or_default();
let zsh_path = zsh_path_override
.or(config_profile.zsh_path.map(Into::into))
.or(cfg.zsh_path.map(Into::into));
.or(config_profile.zsh_path)
.or(cfg.zsh_path);
let review_model = override_review_model.or(cfg.review_model);

View File

@@ -109,7 +109,6 @@ pub mod terminal;
mod tools;
pub mod turn_diff_tracker;
mod turn_metadata;
mod zsh_exec_bridge;
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
pub use rollout::INTERACTIVE_SESSION_SOURCES;
pub use rollout::RolloutRecorder;
@@ -144,7 +143,17 @@ pub(crate) use codex_shell_command::is_safe_command;
pub(crate) use codex_shell_command::parse_command;
pub(crate) use codex_shell_command::powershell;
pub use client::ModelClient;
pub use client::ModelClientSession;
pub use client::ResponsesWebsocketVersion;
pub use client::X_CODEX_TURN_METADATA_HEADER;
pub use client::ws_version_from_features;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use compact::content_items_to_text;
pub use event_mapping::parse_turn_item;
pub use exec_policy::ExecPolicyError;
pub use exec_policy::check_execpolicy_for_warnings;
pub use exec_policy::format_exec_policy_error_with_source;
@@ -153,18 +162,6 @@ pub use file_watcher::FileWatcherEvent;
pub use safety::get_platform_sandbox;
pub use tools::spec::parse_tool_input_schema;
pub use turn_metadata::build_turn_metadata_header;
pub use zsh_exec_bridge::maybe_run_zsh_exec_wrapper_mode;
pub use client::ModelClient;
pub use client::ModelClientSession;
pub use client::ResponsesWebsocketVersion;
pub use client::ws_version_from_features;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use compact::content_items_to_text;
pub use event_mapping::parse_turn_item;
pub mod compact;
pub mod memory_trace;
pub mod otel_init;

View File

@@ -163,19 +163,13 @@ impl SandboxManager {
SandboxType::MacosSeatbelt => {
let mut seatbelt_env = HashMap::new();
seatbelt_env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
let zsh_exec_bridge_wrapper_socket = env
.get(crate::zsh_exec_bridge::ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR)
.map(PathBuf::from);
let zsh_exec_bridge_allowed_unix_sockets = zsh_exec_bridge_wrapper_socket
.as_ref()
.map_or_else(Vec::new, |path| vec![path.clone()]);
let mut args = create_seatbelt_command_args(
command.clone(),
policy,
sandbox_policy_cwd,
enforce_managed_network,
network,
&zsh_exec_bridge_allowed_unix_sockets,
&[],
);
let mut full_command = Vec::with_capacity(1 + args.len());
full_command.push(MACOS_PATH_TO_SEATBELT_EXECUTABLE.to_string());

View File

@@ -15,9 +15,9 @@ use crate::state_db::StateDbHandle;
use crate::tools::network_approval::NetworkApprovalService;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecProcessManager;
use crate::zsh_exec_bridge::ZshExecBridge;
use codex_hooks::Hooks;
use codex_otel::OtelManager;
use codex_utils_absolute_path::AbsolutePathBuf;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::watch;
@@ -27,7 +27,7 @@ pub(crate) struct SessionServices {
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
pub(crate) mcp_startup_cancellation_token: Mutex<CancellationToken>,
pub(crate) unified_exec_manager: UnifiedExecProcessManager,
pub(crate) zsh_exec_bridge: ZshExecBridge,
pub(crate) shell_zsh_path: Option<AbsolutePathBuf>,
pub(crate) analytics_events_client: AnalyticsEventsClient,
pub(crate) hooks: Hooks,
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,

View File

@@ -31,6 +31,7 @@ use async_trait::async_trait;
use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::sync::Arc;
pub struct ApplyPatchHandler;
@@ -139,8 +140,8 @@ impl ToolHandler for ApplyPatchHandler {
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ApplyPatchRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
session: session.clone(),
turn: turn.clone(),
call_id: call_id.clone(),
tool_name: tool_name.to_string(),
};
@@ -149,7 +150,7 @@ impl ToolHandler for ApplyPatchHandler {
&mut runtime,
&req,
&tool_ctx,
&turn,
turn.as_ref(),
turn.approval_policy.value(),
)
.await
@@ -193,8 +194,8 @@ pub(crate) async fn intercept_apply_patch(
command: &[String],
cwd: &Path,
timeout_ms: Option<u64>,
session: &Session,
turn: &TurnContext,
session: Arc<Session>,
turn: Arc<TurnContext>,
tracker: Option<&SharedTurnDiffTracker>,
call_id: &str,
tool_name: &str,
@@ -203,11 +204,13 @@ pub(crate) async fn intercept_apply_patch(
codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
session
.record_model_warning(
format!("apply_patch was requested via {tool_name}. Use the apply_patch tool instead of exec_command."),
turn,
format!(
"apply_patch was requested via {tool_name}. Use the apply_patch tool instead of exec_command."
),
turn.as_ref(),
)
.await;
match apply_patch::apply_patch(turn, changes).await {
match apply_patch::apply_patch(turn.as_ref(), changes).await {
InternalApplyPatchInvocation::Output(item) => {
let content = item?;
Ok(Some(ToolOutput::Function {
@@ -219,8 +222,12 @@ pub(crate) async fn intercept_apply_patch(
let changes = convert_apply_patch_to_protocol(&apply.action);
let approval_keys = file_paths_for_action(&apply.action);
let emitter = ToolEmitter::apply_patch(changes.clone(), apply.auto_approved);
let event_ctx =
ToolEventCtx::new(session, turn, call_id, tracker.as_ref().copied());
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
call_id,
tracker.as_ref().copied(),
);
emitter.begin(event_ctx).await;
let req = ApplyPatchRequest {
@@ -235,8 +242,8 @@ pub(crate) async fn intercept_apply_patch(
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ApplyPatchRuntime::new();
let tool_ctx = ToolCtx {
session,
turn,
session: session.clone(),
turn: turn.clone(),
call_id: call_id.to_string(),
tool_name: tool_name.to_string(),
};
@@ -245,13 +252,17 @@ pub(crate) async fn intercept_apply_patch(
&mut runtime,
&req,
&tool_ctx,
turn,
turn.as_ref(),
turn.approval_policy.value(),
)
.await
.map(|result| result.output);
let event_ctx =
ToolEventCtx::new(session, turn, call_id, tracker.as_ref().copied());
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
call_id,
tracker.as_ref().copied(),
);
let content = emitter.finish(event_ctx, out).await?;
Ok(Some(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),

View File

@@ -296,8 +296,8 @@ impl ShellHandler {
&exec_params.command,
&exec_params.cwd,
exec_params.expiration.timeout_ms(),
session.as_ref(),
turn.as_ref(),
session.clone(),
turn.clone(),
Some(&tracker),
&call_id,
tool_name.as_str(),
@@ -343,8 +343,8 @@ impl ShellHandler {
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ShellRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
session: session.clone(),
turn: turn.clone(),
call_id: call_id.clone(),
tool_name,
};

View File

@@ -172,8 +172,8 @@ impl ToolHandler for UnifiedExecHandler {
&command,
&cwd,
Some(yield_time_ms),
context.session.as_ref(),
context.turn.as_ref(),
context.session.clone(),
context.turn.clone(),
Some(&tracker),
&context.call_id,
tool_name.as_str(),

View File

@@ -48,7 +48,7 @@ impl ToolOrchestrator {
async fn run_attempt<Rq, Out, T>(
tool: &mut T,
req: &Rq,
tool_ctx: &ToolCtx<'_>,
tool_ctx: &ToolCtx,
attempt: &SandboxAttempt<'_>,
has_managed_network_requirements: bool,
) -> (Result<Out, ToolError>, Option<DeferredNetworkApproval>)
@@ -56,7 +56,7 @@ impl ToolOrchestrator {
T: ToolRuntime<Rq, Out>,
{
let network_approval = begin_network_approval(
tool_ctx.session,
&tool_ctx.session,
&tool_ctx.turn.sub_id,
&tool_ctx.call_id,
has_managed_network_requirements,
@@ -65,8 +65,8 @@ impl ToolOrchestrator {
.await;
let attempt_tool_ctx = ToolCtx {
session: tool_ctx.session,
turn: tool_ctx.turn,
session: tool_ctx.session.clone(),
turn: tool_ctx.turn.clone(),
call_id: tool_ctx.call_id.clone(),
tool_name: tool_ctx.tool_name.clone(),
};
@@ -79,7 +79,7 @@ impl ToolOrchestrator {
match network_approval.mode() {
NetworkApprovalMode::Immediate => {
let finalize_result =
finish_immediate_network_approval(tool_ctx.session, network_approval).await;
finish_immediate_network_approval(&tool_ctx.session, network_approval).await;
if let Err(err) = finalize_result {
return (Err(err), None);
}
@@ -88,7 +88,7 @@ impl ToolOrchestrator {
NetworkApprovalMode::Deferred => {
let deferred = network_approval.into_deferred();
if run_result.is_err() {
finish_deferred_network_approval(tool_ctx.session, deferred).await;
finish_deferred_network_approval(&tool_ctx.session, deferred).await;
return (run_result, None);
}
(run_result, deferred)
@@ -100,7 +100,7 @@ impl ToolOrchestrator {
&mut self,
tool: &mut T,
req: &Rq,
tool_ctx: &ToolCtx<'_>,
tool_ctx: &ToolCtx,
turn_ctx: &crate::codex::TurnContext,
approval_policy: AskForApproval,
) -> Result<OrchestratorRunResult<Out>, ToolError>
@@ -128,7 +128,7 @@ impl ToolOrchestrator {
}
ExecApprovalRequirement::NeedsApproval { reason, .. } => {
let approval_ctx = ApprovalCtx {
session: tool_ctx.session,
session: &tool_ctx.session,
turn: turn_ctx,
call_id: &tool_ctx.call_id,
retry_reason: reason,
@@ -256,7 +256,7 @@ impl ToolOrchestrator {
&& network_approval_context.is_none();
if !bypass_retry_approval {
let approval_ctx = ApprovalCtx {
session: tool_ctx.session,
session: &tool_ctx.session,
turn: turn_ctx,
call_id: &tool_ctx.call_id,
retry_reason: Some(retry_reason),

View File

@@ -70,7 +70,7 @@ impl ApplyPatchRuntime {
})
}
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
fn stdout_stream(ctx: &ToolCtx) -> Option<crate::exec::StdoutStream> {
Some(crate::exec::StdoutStream {
sub_id: ctx.turn.sub_id.clone(),
call_id: ctx.call_id.clone(),
@@ -156,7 +156,7 @@ impl ToolRuntime<ApplyPatchRequest, ExecToolCallOutput> for ApplyPatchRuntime {
&mut self,
req: &ApplyPatchRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
ctx: &ToolCtx,
) -> Result<ExecToolCallOutput, ToolError> {
let spec = Self::build_command_spec(req)?;
let env = attempt

View File

@@ -5,7 +5,11 @@ Executes shell requests under the orchestrator: asks for approval when needed,
builds a CommandSpec, and runs it under the current SandboxAttempt.
*/
use crate::command_canonicalization::canonicalize_command_for_approval;
use crate::error::CodexErr;
use crate::error::SandboxErr;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::is_likely_sandbox_denied;
use crate::features::Feature;
use crate::powershell::prefix_powershell_script_with_utf8;
use crate::sandboxing::SandboxPermissions;
@@ -26,19 +30,56 @@ use crate::tools::sandboxing::ToolCtx;
use crate::tools::sandboxing::ToolError;
use crate::tools::sandboxing::ToolRuntime;
use crate::tools::sandboxing::with_cached_approval;
use crate::zsh_exec_bridge::ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR;
use codex_execpolicy::Decision;
use codex_execpolicy::Policy;
use codex_execpolicy::RuleMatch;
use codex_network_proxy::NetworkProxy;
use codex_protocol::config_types::WindowsSandboxLevel;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::SandboxPolicy;
use codex_shell_command::bash::parse_shell_lc_plain_commands;
use codex_shell_command::bash::parse_shell_lc_single_command_prefix;
#[cfg(unix)]
use codex_shell_escalation::unix::core_shell_escalation::ShellActionProvider;
#[cfg(unix)]
use codex_shell_escalation::unix::core_shell_escalation::ShellPolicyFactory;
#[cfg(unix)]
use codex_shell_escalation::unix::escalate_protocol::EscalateAction;
#[cfg(unix)]
use codex_shell_escalation::unix::escalate_server::ExecParams;
#[cfg(unix)]
use codex_shell_escalation::unix::escalate_server::ExecResult;
#[cfg(unix)]
use codex_shell_escalation::unix::escalate_server::SandboxState;
#[cfg(unix)]
use codex_shell_escalation::unix::escalate_server::ShellCommandExecutor;
#[cfg(unix)]
use codex_shell_escalation::unix::escalate_server::run_escalate_server;
#[cfg(unix)]
use codex_shell_escalation::unix::stopwatch::Stopwatch;
#[cfg(unix)]
use codex_utils_absolute_path::AbsolutePathBuf;
use futures::future::BoxFuture;
use shlex::try_join as shlex_try_join;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
#[cfg(unix)]
use std::sync::Arc;
#[cfg(unix)]
use std::time::Duration;
#[cfg(unix)]
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
#[derive(Clone, Debug)]
pub struct ShellRequest {
pub command: Vec<String>,
pub cwd: PathBuf,
pub timeout_ms: Option<u64>,
pub env: std::collections::HashMap<String, String>,
pub explicit_env_overrides: std::collections::HashMap<String, String>,
pub env: HashMap<String, String>,
pub explicit_env_overrides: HashMap<String, String>,
pub network: Option<NetworkProxy>,
pub sandbox_permissions: SandboxPermissions,
pub justification: Option<String>,
@@ -60,7 +101,7 @@ impl ShellRuntime {
Self
}
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
fn stdout_stream(ctx: &ToolCtx) -> Option<crate::exec::StdoutStream> {
Some(crate::exec::StdoutStream {
sub_id: ctx.turn.sub_id.clone(),
call_id: ctx.call_id.clone(),
@@ -73,6 +114,7 @@ impl Sandboxable for ShellRuntime {
fn sandbox_preference(&self) -> SandboxablePreference {
SandboxablePreference::Auto
}
fn escalate_on_failure(&self) -> bool {
true
}
@@ -146,11 +188,214 @@ impl Approvable<ShellRequest> for ShellRuntime {
}
}
#[cfg(unix)]
struct CoreShellActionProvider {
policy: Arc<RwLock<Policy>>,
session: std::sync::Arc<crate::codex::Session>,
turn: std::sync::Arc<crate::codex::TurnContext>,
call_id: String,
approval_policy: AskForApproval,
sandbox_policy: SandboxPolicy,
sandbox_permissions: SandboxPermissions,
}
#[cfg(unix)]
impl CoreShellActionProvider {
fn decision_driven_by_policy(matched_rules: &[RuleMatch], decision: Decision) -> bool {
matched_rules.iter().any(|rule_match| {
!matches!(rule_match, RuleMatch::HeuristicsRuleMatch { .. })
&& rule_match.decision() == decision
})
}
async fn prompt(
&self,
command: &[String],
workdir: &Path,
stopwatch: &Stopwatch,
) -> anyhow::Result<ReviewDecision> {
let command = command.to_vec();
let workdir = workdir.to_path_buf();
let session = self.session.clone();
let turn = self.turn.clone();
let call_id = self.call_id.clone();
Ok(stopwatch
.pause_for(async move {
session
.request_command_approval(
&turn, call_id, None, command, workdir, None, None, None,
)
.await
})
.await)
}
}
#[cfg(unix)]
#[async_trait::async_trait]
impl ShellActionProvider for CoreShellActionProvider {
async fn determine_action(
&self,
file: &Path,
argv: &[String],
workdir: &Path,
stopwatch: &Stopwatch,
) -> anyhow::Result<EscalateAction> {
let command = std::iter::once(file.to_string_lossy().to_string())
.chain(argv.iter().cloned())
.collect::<Vec<_>>();
let (commands, used_complex_parsing) =
if let Some(commands) = parse_shell_lc_plain_commands(&command) {
(commands, false)
} else if let Some(single_command) = parse_shell_lc_single_command_prefix(&command) {
(vec![single_command], true)
} else {
(vec![command.clone()], false)
};
let policy = self.policy.read().await;
let fallback = |cmd: &[String]| {
crate::exec_policy::render_decision_for_unmatched_command(
self.approval_policy,
&self.sandbox_policy,
cmd,
self.sandbox_permissions,
used_complex_parsing,
)
};
let evaluation = policy.check_multiple(commands.iter(), &fallback);
let decision_driven_by_policy =
Self::decision_driven_by_policy(&evaluation.matched_rules, evaluation.decision);
let needs_escalation =
self.sandbox_permissions.requires_escalated_permissions() || decision_driven_by_policy;
Ok(match evaluation.decision {
Decision::Forbidden => EscalateAction::Deny {
reason: Some("Execution forbidden by policy".to_string()),
},
Decision::Prompt => {
if self.approval_policy == AskForApproval::Never {
EscalateAction::Deny {
reason: Some("Execution forbidden by policy".to_string()),
}
} else if decision_driven_by_policy {
EscalateAction::Escalate
} else {
match self.prompt(&command, workdir, stopwatch).await? {
ReviewDecision::Approved
| ReviewDecision::ApprovedExecpolicyAmendment { .. }
| ReviewDecision::ApprovedForSession => {
if needs_escalation {
EscalateAction::Escalate
} else {
EscalateAction::Run
}
}
ReviewDecision::Denied => EscalateAction::Deny {
reason: Some("User denied execution".to_string()),
},
ReviewDecision::Abort => EscalateAction::Deny {
reason: Some("User cancelled execution".to_string()),
},
}
}
}
Decision::Allow => EscalateAction::Run,
})
}
}
#[cfg(unix)]
struct CoreShellCommandExecutor;
#[cfg(unix)]
#[async_trait::async_trait]
impl ShellCommandExecutor for CoreShellCommandExecutor {
async fn run(
&self,
command: Vec<String>,
cwd: PathBuf,
env: HashMap<String, String>,
cancel_rx: CancellationToken,
sandbox_state: &SandboxState,
) -> anyhow::Result<ExecResult> {
let result = crate::exec::process_exec_tool_call(
crate::exec::ExecParams {
command,
cwd,
expiration: crate::exec::ExecExpiration::Cancellation(cancel_rx),
env,
network: None,
sandbox_permissions: SandboxPermissions::UseDefault,
windows_sandbox_level: WindowsSandboxLevel::Disabled,
justification: None,
arg0: None,
},
&sandbox_state.sandbox_policy,
&sandbox_state.sandbox_cwd,
&sandbox_state.codex_linux_sandbox_exe,
sandbox_state.use_linux_sandbox_bwrap,
None,
)
.await?;
Ok(ExecResult {
exit_code: result.exit_code,
output: result.aggregated_output.text,
duration: result.duration,
timed_out: result.timed_out,
})
}
}
#[cfg(unix)]
fn shell_execve_wrapper() -> anyhow::Result<PathBuf> {
let exe = std::env::current_exe()?;
exe.parent()
.map(|parent| parent.join("codex-execve-wrapper"))
.ok_or_else(|| anyhow::anyhow!("failed to determine codex-execve-wrapper path"))
}
#[cfg(unix)]
fn shell_exec_zsh_path(path: &AbsolutePathBuf) -> PathBuf {
path.to_path_buf()
}
#[cfg(unix)]
fn map_exec_result(
sandbox: SandboxType,
result: ExecResult,
) -> Result<ExecToolCallOutput, ToolError> {
let output = ExecToolCallOutput {
exit_code: result.exit_code,
stdout: crate::exec::StreamOutput::new(result.output.clone()),
stderr: crate::exec::StreamOutput::new(String::new()),
aggregated_output: crate::exec::StreamOutput::new(result.output.clone()),
duration: result.duration,
timed_out: result.timed_out,
};
if result.timed_out {
return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Timeout {
output: Box::new(output),
})));
}
if is_likely_sandbox_denied(sandbox, &output) {
return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied {
output: Box::new(output),
network_policy_decision: None,
})));
}
Ok(output)
}
impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
fn network_approval_spec(
&self,
req: &ShellRequest,
_ctx: &ToolCtx<'_>,
_ctx: &ToolCtx,
) -> Option<NetworkApprovalSpec> {
req.network.as_ref()?;
Some(NetworkApprovalSpec {
@@ -163,17 +408,15 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
&mut self,
req: &ShellRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
ctx: &ToolCtx,
) -> Result<ExecToolCallOutput, ToolError> {
let base_command = &req.command;
let session_shell = ctx.session.user_shell();
let command = maybe_wrap_shell_lc_with_snapshot(
base_command,
session_shell.as_ref(),
&req.command,
ctx.session.user_shell().as_ref(),
&req.cwd,
&req.explicit_env_overrides,
);
let command = if matches!(session_shell.shell_type, ShellType::PowerShell)
let command = if matches!(ctx.session.user_shell().shell_type, ShellType::PowerShell)
&& ctx.session.features().enabled(Feature::PowershellUtf8)
{
prefix_powershell_script_with_utf8(&command)
@@ -181,21 +424,15 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
command
};
if ctx.session.features().enabled(Feature::ShellZshFork) {
let wrapper_socket_path = ctx
.session
.services
.zsh_exec_bridge
.next_wrapper_socket_path();
let mut zsh_fork_env = req.env.clone();
zsh_fork_env.insert(
ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR.to_string(),
wrapper_socket_path.to_string_lossy().to_string(),
);
#[cfg(unix)]
if let Some(shell_zsh_path) = ctx.session.services.shell_zsh_path.as_ref()
&& ctx.session.features().enabled(Feature::ShellZshFork)
&& matches!(ctx.session.user_shell().shell_type, ShellType::Zsh)
{
let spec = build_command_spec(
&command,
&req.cwd,
&zsh_fork_env,
&req.env,
req.timeout_ms.into(),
req.sandbox_permissions,
req.justification.clone(),
@@ -203,12 +440,52 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
let env = attempt
.env_for(spec, req.network.as_ref())
.map_err(|err| ToolError::Codex(err.into()))?;
return ctx
.session
.services
.zsh_exec_bridge
.execute_shell_request(&env, ctx.session, ctx.turn, &ctx.call_id)
.await;
let (_, args) = env
.command
.split_first()
.ok_or_else(|| ToolError::Rejected("command args are empty".to_string()))?;
let script = shlex_try_join(args.iter().map(String::as_str))
.map_err(|err| ToolError::Rejected(format!("serialize shell script: {err}")))?;
let effective_timeout = Duration::from_millis(
req.timeout_ms
.unwrap_or(crate::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
);
let exec_policy = Arc::new(RwLock::new(
ctx.session.services.exec_policy.current().as_ref().clone(),
));
let sandbox_state = SandboxState {
sandbox_policy: ctx.turn.sandbox_policy.get().clone(),
codex_linux_sandbox_exe: attempt.codex_linux_sandbox_exe.cloned(),
sandbox_cwd: req.cwd.clone(),
use_linux_sandbox_bwrap: attempt.use_linux_sandbox_bwrap,
};
let exec_result = run_escalate_server(
ExecParams {
command: script,
workdir: req.cwd.to_string_lossy().to_string(),
timeout_ms: Some(effective_timeout.as_millis() as u64),
login: Some(false),
},
&sandbox_state,
shell_exec_zsh_path(shell_zsh_path),
shell_execve_wrapper().map_err(|err| ToolError::Rejected(format!("{err}")))?,
exec_policy.clone(),
ShellPolicyFactory::new(CoreShellActionProvider {
policy: Arc::clone(&exec_policy),
session: Arc::clone(&ctx.session),
turn: Arc::clone(&ctx.turn),
call_id: ctx.call_id.clone(),
approval_policy: ctx.turn.approval_policy.value(),
sandbox_policy: attempt.policy.clone(),
sandbox_permissions: req.sandbox_permissions,
}),
effective_timeout,
&CoreShellCommandExecutor,
)
.await
.map_err(|err| ToolError::Rejected(err.to_string()))?;
return map_exec_result(attempt.sandbox, exec_result);
}
let spec = build_command_spec(

View File

@@ -153,7 +153,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
fn network_approval_spec(
&self,
req: &UnifiedExecRequest,
_ctx: &ToolCtx<'_>,
_ctx: &ToolCtx,
) -> Option<NetworkApprovalSpec> {
req.network.as_ref()?;
Some(NetworkApprovalSpec {
@@ -166,7 +166,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
&mut self,
req: &UnifiedExecRequest,
attempt: &SandboxAttempt<'_>,
ctx: &ToolCtx<'_>,
ctx: &ToolCtx,
) -> Result<UnifiedExecProcess, ToolError> {
let base_command = &req.command;
let session_shell = ctx.session.user_shell();

View File

@@ -18,14 +18,14 @@ use codex_protocol::approvals::ExecPolicyAmendment;
use codex_protocol::approvals::NetworkApprovalContext;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::ReviewDecision;
use futures::Future;
use futures::future::BoxFuture;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::path::Path;
use futures::Future;
use futures::future::BoxFuture;
use serde::Serialize;
use std::sync::Arc;
#[derive(Clone, Default, Debug)]
pub(crate) struct ApprovalStore {
@@ -267,9 +267,9 @@ pub(crate) trait Sandboxable {
}
}
pub(crate) struct ToolCtx<'a> {
pub session: &'a Session,
pub turn: &'a TurnContext,
pub(crate) struct ToolCtx {
pub session: Arc<Session>,
pub turn: Arc<TurnContext>,
pub call_id: String,
pub tool_name: String,
}
@@ -281,7 +281,7 @@ pub(crate) enum ToolError {
}
pub(crate) trait ToolRuntime<Req, Out>: Approvable<Req> + Sandboxable {
fn network_approval_spec(&self, _req: &Req, _ctx: &ToolCtx<'_>) -> Option<NetworkApprovalSpec> {
fn network_approval_spec(&self, _req: &Req, _ctx: &ToolCtx) -> Option<NetworkApprovalSpec> {
None
}

View File

@@ -594,8 +594,8 @@ impl UnifiedExecProcessManager {
exec_approval_requirement,
};
let tool_ctx = ToolCtx {
session: context.session.as_ref(),
turn: context.turn.as_ref(),
session: context.session.clone(),
turn: context.turn.clone(),
call_id: context.call_id.clone(),
tool_name: "exec_command".to_string(),
};
@@ -604,7 +604,7 @@ impl UnifiedExecProcessManager {
&mut runtime,
&req,
&tool_ctx,
context.turn.as_ref(),
&context.turn,
context.turn.approval_policy.value(),
)
.await

View File

@@ -1,557 +0,0 @@
use crate::exec::ExecToolCallOutput;
use crate::tools::sandboxing::ToolError;
use std::path::PathBuf;
use tokio::sync::Mutex;
use uuid::Uuid;
#[cfg(unix)]
use crate::error::CodexErr;
#[cfg(unix)]
use crate::error::SandboxErr;
#[cfg(unix)]
use crate::protocol::EventMsg;
#[cfg(unix)]
use crate::protocol::ExecCommandOutputDeltaEvent;
#[cfg(unix)]
use crate::protocol::ExecOutputStream;
#[cfg(unix)]
use crate::protocol::ReviewDecision;
#[cfg(unix)]
use anyhow::Context as _;
#[cfg(unix)]
use codex_protocol::approvals::ExecPolicyAmendment;
#[cfg(unix)]
use codex_utils_pty::process_group::kill_child_process_group;
#[cfg(unix)]
use serde::Deserialize;
#[cfg(unix)]
use serde::Serialize;
#[cfg(unix)]
use std::io::Read;
#[cfg(unix)]
use std::io::Write;
#[cfg(unix)]
use std::time::Instant;
#[cfg(unix)]
use tokio::io::AsyncReadExt;
#[cfg(unix)]
use tokio::net::UnixListener;
#[cfg(unix)]
use tokio::net::UnixStream;
/// Env var carrying the filesystem path of the unix socket the exec wrapper
/// connects back to for approval decisions.
pub(crate) const ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR: &str =
    "CODEX_ZSH_EXEC_BRIDGE_WRAPPER_SOCKET";
/// When set, the current process runs as the exec wrapper instead of starting
/// normally (see `maybe_run_zsh_exec_wrapper_mode`).
pub(crate) const ZSH_EXEC_WRAPPER_MODE_ENV_VAR: &str = "CODEX_ZSH_EXEC_WRAPPER_MODE";
/// Env var pointing the shell at the wrapper executable used to interpose on
/// exec calls.
#[cfg(unix)]
pub(crate) const EXEC_WRAPPER_ENV_VAR: &str = "EXEC_WRAPPER";
/// Per-session bookkeeping guarded by the bridge's mutex.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub(crate) struct ZshExecBridgeSessionState {
    // Session id recorded by `initialize_for_session`; `None` after `shutdown`.
    pub(crate) initialized_session_id: Option<String>,
}
/// Bridge that runs shell commands via a zsh fork and answers per-exec
/// approval requests from the wrapper over a unix socket.
#[derive(Debug, Default)]
pub(crate) struct ZshExecBridge {
    // Path to the zsh binary; must be configured for `execute_shell_request`
    // on unix (it is rejected otherwise).
    zsh_path: Option<PathBuf>,
    state: Mutex<ZshExecBridgeSessionState>,
}
/// JSON line sent by the exec wrapper to the host over the wrapper socket.
#[cfg(unix)]
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum WrapperIpcRequest {
    /// One intercepted exec: the target `file`, its `argv`, and the wrapper's
    /// working directory; `request_id` correlates the response.
    ExecRequest {
        request_id: String,
        file: String,
        argv: Vec<String>,
        cwd: String,
    },
}
/// JSON line sent back by the host to the exec wrapper.
#[cfg(unix)]
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum WrapperIpcResponse {
    /// Decision for the exec identified by `request_id`; `reason` accompanies
    /// a denial.
    ExecResponse {
        request_id: String,
        action: WrapperExecAction,
        reason: Option<String>,
    },
}
/// Verdict the host gives the wrapper for an intercepted exec.
#[cfg(unix)]
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum WrapperExecAction {
    // Execute the target binary.
    Run,
    // Do not execute; the wrapper exits non-zero.
    Deny,
}
impl ZshExecBridge {
    /// Create a bridge configured with an optional zsh binary path.
    // `_codex_home` is accepted but currently unused.
    pub(crate) fn new(zsh_path: Option<PathBuf>, _codex_home: PathBuf) -> Self {
        Self {
            zsh_path,
            state: Mutex::new(ZshExecBridgeSessionState::default()),
        }
    }

    /// Record the session id this bridge is serving.
    pub(crate) async fn initialize_for_session(&self, session_id: &str) {
        let mut state = self.state.lock().await;
        state.initialized_session_id = Some(session_id.to_string());
    }

    /// Clear any recorded session id.
    pub(crate) async fn shutdown(&self) {
        let mut state = self.state.lock().await;
        state.initialized_session_id = None;
    }

    /// Produce a unique wrapper-IPC socket path under the system temp dir.
    // The temp dir is canonicalized (e.g. macOS `/var` -> `/private/var`) and
    // the uuid is truncated to 12 hex chars — presumably to keep the path
    // under the unix-socket path length limit; confirm.
    pub(crate) fn next_wrapper_socket_path(&self) -> PathBuf {
        let socket_id = Uuid::new_v4().as_simple().to_string();
        let temp_dir = std::env::temp_dir();
        let canonical_temp_dir = temp_dir.canonicalize().unwrap_or(temp_dir);
        canonical_temp_dir.join(format!("czs-{}.sock", &socket_id[..12]))
    }

    /// Non-unix stub: the zsh fork tool is rejected outright.
    #[cfg(not(unix))]
    pub(crate) async fn execute_shell_request(
        &self,
        _req: &crate::sandboxing::ExecRequest,
        _session: &crate::codex::Session,
        _turn: &crate::codex::TurnContext,
        _call_id: &str,
    ) -> Result<ExecToolCallOutput, ToolError> {
        let _ = &self.zsh_path;
        Err(ToolError::Rejected(
            "shell_zsh_fork is only supported on unix".to_string(),
        ))
    }

    /// Run `req.command` with the wrapper env vars injected, stream
    /// stdout/stderr back to the session as delta events, and answer per-exec
    /// approval requests arriving on the wrapper socket until the child exits.
    #[cfg(unix)]
    pub(crate) async fn execute_shell_request(
        &self,
        req: &crate::sandboxing::ExecRequest,
        session: &crate::codex::Session,
        turn: &crate::codex::TurnContext,
        call_id: &str,
    ) -> Result<ExecToolCallOutput, ToolError> {
        // NOTE(review): `zsh_path` is only checked for presence and used in
        // error text here; the spawned program is `command[0]` — confirm that
        // is intended.
        let zsh_path = self.zsh_path.clone().ok_or_else(|| {
            ToolError::Rejected(
                "shell_zsh_fork enabled, but zsh_path is not configured".to_string(),
            )
        })?;
        let command = req.command.clone();
        if command.is_empty() {
            return Err(ToolError::Rejected("command args are empty".to_string()));
        }
        // Reuse a caller-provided socket path from the request env when
        // present; otherwise mint a fresh one.
        let wrapper_socket_path = req
            .env
            .get(ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR)
            .map(PathBuf::from)
            .unwrap_or_else(|| self.next_wrapper_socket_path());
        let listener = {
            // Remove any stale socket file before binding.
            let _ = std::fs::remove_file(&wrapper_socket_path);
            UnixListener::bind(&wrapper_socket_path).map_err(|err| {
                ToolError::Rejected(format!(
                    "bind wrapper socket at {}: {err}",
                    wrapper_socket_path.display()
                ))
            })?
        };
        // The wrapper re-invokes this same executable in wrapper mode.
        let wrapper_path = std::env::current_exe().map_err(|err| {
            ToolError::Rejected(format!("resolve current executable path: {err}"))
        })?;
        let mut cmd = tokio::process::Command::new(&command[0]);
        #[cfg(unix)]
        if let Some(arg0) = &req.arg0 {
            cmd.arg0(arg0);
        }
        if command.len() > 1 {
            cmd.args(&command[1..]);
        }
        cmd.current_dir(&req.cwd);
        cmd.stdin(std::process::Stdio::null());
        cmd.stdout(std::process::Stdio::piped());
        cmd.stderr(std::process::Stdio::piped());
        cmd.kill_on_drop(true);
        // Start from a clean environment so only `req.env` plus the wrapper
        // plumbing variables reach the child.
        cmd.env_clear();
        cmd.envs(&req.env);
        cmd.env(
            ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR,
            wrapper_socket_path.to_string_lossy().to_string(),
        );
        cmd.env(EXEC_WRAPPER_ENV_VAR, &wrapper_path);
        cmd.env(ZSH_EXEC_WRAPPER_MODE_ENV_VAR, "1");
        let mut child = cmd.spawn().map_err(|err| {
            ToolError::Rejected(format!(
                "failed to start zsh fork command {} with zsh_path {}: {err}",
                command[0],
                zsh_path.display()
            ))
        })?;
        // Fan stdout/stderr chunks into one channel so the select loop below
        // can forward them as ordered delta events.
        let (stream_tx, mut stream_rx) =
            tokio::sync::mpsc::unbounded_channel::<(ExecOutputStream, Vec<u8>)>();
        if let Some(mut out) = child.stdout.take() {
            let tx = stream_tx.clone();
            tokio::spawn(async move {
                let mut buf = [0_u8; 8192];
                loop {
                    let read = match out.read(&mut buf).await {
                        Ok(0) => break,
                        Ok(n) => n,
                        Err(err) => {
                            tracing::warn!("zsh fork stdout read error: {err}");
                            break;
                        }
                    };
                    let _ = tx.send((ExecOutputStream::Stdout, buf[..read].to_vec()));
                }
            });
        }
        if let Some(mut err) = child.stderr.take() {
            let tx = stream_tx.clone();
            tokio::spawn(async move {
                let mut buf = [0_u8; 8192];
                loop {
                    let read = match err.read(&mut buf).await {
                        Ok(0) => break,
                        Ok(n) => n,
                        Err(err) => {
                            tracing::warn!("zsh fork stderr read error: {err}");
                            break;
                        }
                    };
                    let _ = tx.send((ExecOutputStream::Stderr, buf[..read].to_vec()));
                }
            });
        }
        // Drop our sender so `stream_rx` closes once both reader tasks finish.
        drop(stream_tx);
        let mut stdout_bytes = Vec::new();
        let mut stderr_bytes = Vec::new();
        let mut child_exit = None;
        let mut timed_out = false;
        let mut stream_open = true;
        let mut user_rejected = false;
        let start = Instant::now();
        let expiration = req.expiration.clone().wait();
        tokio::pin!(expiration);
        // Drive child exit, output streaming, wrapper approvals, and timeout
        // concurrently; keep draining output after the child exits until the
        // channel closes.
        while child_exit.is_none() || stream_open {
            tokio::select! {
                result = child.wait(), if child_exit.is_none() => {
                    child_exit = Some(result.map_err(|err| ToolError::Rejected(format!("wait for zsh fork command exit: {err}")))?);
                }
                stream = stream_rx.recv(), if stream_open => {
                    if let Some((output_stream, chunk)) = stream {
                        match output_stream {
                            ExecOutputStream::Stdout => stdout_bytes.extend_from_slice(&chunk),
                            ExecOutputStream::Stderr => stderr_bytes.extend_from_slice(&chunk),
                        }
                        session
                            .send_event(
                                turn,
                                EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
                                    call_id: call_id.to_string(),
                                    stream: output_stream,
                                    chunk,
                                }),
                            )
                            .await;
                    } else {
                        stream_open = false;
                    }
                }
                accept_result = listener.accept(), if child_exit.is_none() => {
                    let (stream, _) = accept_result.map_err(|err| {
                        ToolError::Rejected(format!("failed to accept wrapper request: {err}"))
                    })?;
                    if self
                        .handle_wrapper_request(stream, req.justification.clone(), session, turn, call_id)
                        .await?
                    {
                        user_rejected = true;
                    }
                }
                _ = &mut expiration, if child_exit.is_none() => {
                    timed_out = true;
                    kill_child_process_group(&mut child).map_err(|err| {
                        ToolError::Rejected(format!("kill zsh fork command process group: {err}"))
                    })?;
                    child.start_kill().map_err(|err| {
                        ToolError::Rejected(format!("kill zsh fork command process: {err}"))
                    })?;
                }
            }
        }
        // NOTE(review): the early `?` returns above skip this cleanup, leaving
        // the socket file behind — consider a drop guard if that matters.
        let _ = std::fs::remove_file(&wrapper_socket_path);
        let status = child_exit.ok_or_else(|| {
            ToolError::Rejected("zsh fork command did not return exit status".to_string())
        })?;
        if user_rejected {
            return Err(ToolError::Rejected("rejected by user".to_string()));
        }
        let stdout_text = crate::text_encoding::bytes_to_string_smart(&stdout_bytes);
        let stderr_text = crate::text_encoding::bytes_to_string_smart(&stderr_bytes);
        let output = ExecToolCallOutput {
            // -1 signals the child was terminated by a signal (no exit code).
            exit_code: status.code().unwrap_or(-1),
            stdout: crate::exec::StreamOutput::new(stdout_text.clone()),
            stderr: crate::exec::StreamOutput::new(stderr_text.clone()),
            aggregated_output: crate::exec::StreamOutput::new(format!(
                "{stdout_text}{stderr_text}"
            )),
            duration: start.elapsed(),
            timed_out,
        };
        Self::map_exec_result(req.sandbox, output)
    }

    /// Handle one wrapper connection: read its exec request, ask the user for
    /// approval, and reply run/deny. Returns `true` when the user rejected.
    #[cfg(unix)]
    async fn handle_wrapper_request(
        &self,
        mut stream: UnixStream,
        approval_reason: Option<String>,
        session: &crate::codex::Session,
        turn: &crate::codex::TurnContext,
        call_id: &str,
    ) -> Result<bool, ToolError> {
        // The wrapper half-closes its write side, so read_to_end terminates.
        let mut request_buf = Vec::new();
        stream.read_to_end(&mut request_buf).await.map_err(|err| {
            ToolError::Rejected(format!("read wrapper request from socket: {err}"))
        })?;
        let request_line = String::from_utf8(request_buf).map_err(|err| {
            ToolError::Rejected(format!("decode wrapper request as utf-8: {err}"))
        })?;
        let request = parse_wrapper_request_line(request_line.trim())?;
        let (request_id, file, argv, cwd) = match request {
            WrapperIpcRequest::ExecRequest {
                request_id,
                file,
                argv,
                cwd,
            } => (request_id, file, argv, cwd),
        };
        // Fall back to the bare file path when the wrapper supplied no argv.
        let command_for_approval = if argv.is_empty() {
            vec![file.clone()]
        } else {
            argv.clone()
        };
        let approval_id = Uuid::new_v4().to_string();
        let decision = session
            .request_command_approval(
                turn,
                call_id.to_string(),
                Some(approval_id),
                command_for_approval,
                PathBuf::from(cwd),
                approval_reason,
                None,
                None::<ExecPolicyAmendment>,
            )
            .await;
        let (action, reason, user_rejected) = match decision {
            ReviewDecision::Approved
            | ReviewDecision::ApprovedForSession
            | ReviewDecision::ApprovedExecpolicyAmendment { .. } => {
                (WrapperExecAction::Run, None, false)
            }
            ReviewDecision::Denied => (
                WrapperExecAction::Deny,
                Some("command denied by host approval policy".to_string()),
                true,
            ),
            ReviewDecision::Abort => (
                WrapperExecAction::Deny,
                Some("command aborted by host approval policy".to_string()),
                true,
            ),
        };
        write_json_line(
            &mut stream,
            &WrapperIpcResponse::ExecResponse {
                request_id,
                action,
                reason,
            },
        )
        .await?;
        Ok(user_rejected)
    }

    /// Convert the raw exec output into the tool result, mapping timeouts and
    /// likely sandbox denials onto sandbox errors.
    #[cfg(unix)]
    fn map_exec_result(
        sandbox: crate::exec::SandboxType,
        output: ExecToolCallOutput,
    ) -> Result<ExecToolCallOutput, ToolError> {
        if output.timed_out {
            return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Timeout {
                output: Box::new(output),
            })));
        }
        if crate::exec::is_likely_sandbox_denied(sandbox, &output) {
            return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied {
                output: Box::new(output),
                network_policy_decision: None,
            })));
        }
        Ok(output)
    }
}
/// Entry hook for binaries: when the wrapper-mode env var is present, run the
/// exec wrapper protocol and report `true`; otherwise report `false` so the
/// caller proceeds with normal startup.
pub fn maybe_run_zsh_exec_wrapper_mode() -> anyhow::Result<bool> {
    match std::env::var_os(ZSH_EXEC_WRAPPER_MODE_ENV_VAR) {
        None => Ok(false),
        Some(_) => {
            run_exec_wrapper_mode()?;
            Ok(true)
        }
    }
}
/// Body of wrapper mode: forward the intercepted exec to the host over the
/// wrapper socket, wait for the run/deny decision, then either spawn the real
/// target (minus the wrapper env vars) or exit non-zero. Never returns on the
/// unix path — it terminates via `std::process::exit`.
fn run_exec_wrapper_mode() -> anyhow::Result<()> {
    #[cfg(not(unix))]
    {
        anyhow::bail!("zsh exec wrapper mode is only supported on unix");
    }
    #[cfg(unix)]
    {
        use std::os::unix::net::UnixStream as StdUnixStream;
        let args: Vec<String> = std::env::args().collect();
        if args.len() < 2 {
            anyhow::bail!("exec wrapper mode requires target executable path");
        }
        // args[1] is the target binary; args[2..] is treated as the target's
        // full argv (program name included) — TODO confirm this matches how
        // the shell invokes the wrapper.
        let file = args[1].clone();
        let argv = if args.len() > 2 {
            args[2..].to_vec()
        } else {
            vec![file.clone()]
        };
        let cwd = std::env::current_dir()
            .context("resolve wrapper cwd")?
            .to_string_lossy()
            .to_string();
        let socket_path = std::env::var(ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR)
            .context("missing wrapper socket path env var")?;
        let request_id = Uuid::new_v4().to_string();
        let request = WrapperIpcRequest::ExecRequest {
            request_id: request_id.clone(),
            file: file.clone(),
            argv: argv.clone(),
            cwd,
        };
        let mut stream = StdUnixStream::connect(&socket_path)
            .with_context(|| format!("connect to wrapper socket at {socket_path}"))?;
        let encoded = serde_json::to_string(&request).context("serialize wrapper request")?;
        stream
            .write_all(encoded.as_bytes())
            .context("write wrapper request")?;
        stream
            .write_all(b"\n")
            .context("write wrapper request newline")?;
        // Half-close the write side so the host's read_to_end sees EOF.
        stream
            .shutdown(std::net::Shutdown::Write)
            .context("shutdown wrapper write")?;
        let mut response_buf = String::new();
        stream
            .read_to_string(&mut response_buf)
            .context("read wrapper response")?;
        let response: WrapperIpcResponse =
            serde_json::from_str(response_buf.trim()).context("parse wrapper response")?;
        let (response_request_id, action, reason) = match response {
            WrapperIpcResponse::ExecResponse {
                request_id,
                action,
                reason,
            } => (request_id, action, reason),
        };
        // Guard against a response meant for a different request.
        if response_request_id != request_id {
            anyhow::bail!(
                "wrapper response request_id mismatch: expected {request_id}, got {response_request_id}"
            );
        }
        if action == WrapperExecAction::Deny {
            // NOTE(review): no tracing subscriber appears to be installed in
            // wrapper mode here, so these warnings may be dropped — confirm.
            if let Some(reason) = reason {
                tracing::warn!("execution denied: {reason}");
            } else {
                tracing::warn!("execution denied");
            }
            std::process::exit(1);
        }
        let mut command = std::process::Command::new(&file);
        if argv.len() > 1 {
            // argv[0] is the program name; pass only the remaining arguments.
            command.args(&argv[1..]);
        }
        // Strip the wrapper plumbing so the child does not re-enter wrapper
        // mode or inherit the socket path.
        command.env_remove(ZSH_EXEC_WRAPPER_MODE_ENV_VAR);
        command.env_remove(ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR);
        command.env_remove(EXEC_WRAPPER_ENV_VAR);
        let status = command.status().context("spawn wrapped executable")?;
        // Exit code 1 when the child was killed by a signal (no code).
        std::process::exit(status.code().unwrap_or(1));
    }
}
/// Decode one JSON line from the wrapper into a `WrapperIpcRequest`,
/// surfacing malformed payloads as `ToolError::Rejected`.
#[cfg(unix)]
fn parse_wrapper_request_line(request_line: &str) -> Result<WrapperIpcRequest, ToolError> {
    match serde_json::from_str(request_line) {
        Ok(request) => Ok(request),
        Err(err) => Err(ToolError::Rejected(format!(
            "parse wrapper request payload: {err}"
        ))),
    }
}
/// Serialize `message` as JSON, write it plus a trailing newline to `writer`,
/// then flush; any failure is reported as `ToolError::Rejected`.
#[cfg(unix)]
async fn write_json_line<W: tokio::io::AsyncWrite + Unpin, T: Serialize>(
    writer: &mut W,
    message: &T,
) -> Result<(), ToolError> {
    use tokio::io::AsyncWriteExt as _;

    let encoded = match serde_json::to_string(message) {
        Ok(encoded) => encoded,
        Err(err) => {
            return Err(ToolError::Rejected(format!(
                "serialize wrapper message: {err}"
            )));
        }
    };
    writer
        .write_all(encoded.as_bytes())
        .await
        .map_err(|err| ToolError::Rejected(format!("write wrapper message: {err}")))?;
    writer
        .write_all(b"\n")
        .await
        .map_err(|err| ToolError::Rejected(format!("write wrapper newline: {err}")))?;
    writer
        .flush()
        .await
        .map_err(|err| ToolError::Rejected(format!("flush wrapper message: {err}")))?;
    Ok(())
}
#[cfg(all(test, unix))]
mod tests {
    use super::*;

    /// Malformed JSON from the wrapper must surface as a rejection with the
    /// parse-error prefix, not a panic or silent failure.
    #[test]
    fn parse_wrapper_request_line_rejects_malformed_json() {
        let result = parse_wrapper_request_line("this-is-not-json");
        match result {
            Err(ToolError::Rejected(message)) => {
                assert!(message.starts_with("parse wrapper request payload:"));
            }
            other => panic!("expected ToolError::Rejected, got {other:?}"),
        }
    }
}

View File

@@ -32,7 +32,16 @@ codex-core = { workspace = true }
codex-execpolicy = { workspace = true }
codex-protocol = { workspace = true }
codex-shell-command = { workspace = true }
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
codex-shell-escalation = { workspace = true }
||||||| base
libc = { workspace = true }
path-absolutize = { workspace = true }
=======
[target.'cfg(unix)'.dependencies]
codex-shell-escalation = { workspace = true }
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
rmcp = { workspace = true, default-features = false, features = [
"auth",
"elicitation",
@@ -50,7 +59,22 @@ schemars = { version = "1.2.1" }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
shlex = { workspace = true }
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
||||||| base
socket2 = { workspace = true }
tokio = { workspace = true, features = [
"io-std",
"macros",
"process",
"rt-multi-thread",
"signal",
] }
tokio-util = { workspace = true }
=======
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
tokio-util = { workspace = true }
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }

View File

@@ -67,7 +67,7 @@ use codex_execpolicy::Decision;
use codex_execpolicy::Policy;
use codex_execpolicy::RuleMatch;
use codex_shell_command::is_dangerous_command::command_might_be_dangerous;
use codex_shell_escalation as shell_escalation;
use codex_shell_escalation::unix::escalate_client::run;
use rmcp::ErrorData as McpError;
use tokio::sync::RwLock;
use tracing_subscriber::EnvFilter;
@@ -160,7 +160,7 @@ pub async fn main_execve_wrapper() -> anyhow::Result<()> {
.init();
let ExecveWrapperCli { file, argv } = ExecveWrapperCli::parse();
let exit_code = shell_escalation::run(file, argv).await?;
let exit_code = run(file, argv).await?;
std::process::exit(exit_code);
}

View File

@@ -6,11 +6,19 @@ use anyhow::Context as _;
use anyhow::Result;
use codex_core::MCP_SANDBOX_STATE_CAPABILITY;
use codex_core::MCP_SANDBOX_STATE_METHOD;
use codex_core::SandboxState;
use codex_core::SandboxState as CoreSandboxState;
use codex_core::exec::process_exec_tool_call;
use codex_execpolicy::Policy;
use codex_protocol::config_types::WindowsSandboxLevel;
use codex_protocol::models::SandboxPermissions as ProtocolSandboxPermissions;
use codex_protocol::protocol::SandboxPolicy;
use codex_shell_escalation::EscalationPolicyFactory;
use codex_shell_escalation::run_escalate_server;
use codex_shell_escalation::unix::escalate_server::EscalationPolicyFactory;
use codex_shell_escalation::unix::escalate_server::ExecParams as ShellExecParams;
use codex_shell_escalation::unix::escalate_server::ExecResult as ShellExecResult;
use codex_shell_escalation::unix::escalate_server::SandboxState as ShellEscalationSandboxState;
use codex_shell_escalation::unix::escalate_server::ShellCommandExecutor;
use codex_shell_escalation::unix::escalate_server::run_escalate_server;
use codex_shell_escalation::unix::stopwatch::Stopwatch;
use rmcp::ErrorData as McpError;
use rmcp::RoleServer;
use rmcp::ServerHandler;
@@ -27,7 +35,9 @@ use rmcp::tool_handler;
use rmcp::tool_router;
use rmcp::transport::stdio;
use serde_json::json;
use std::collections::HashMap;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::unix::mcp_escalation_policy::McpEscalationPolicy;
@@ -50,8 +60,8 @@ pub struct ExecResult {
pub timed_out: bool,
}
impl From<codex_shell_escalation::ExecResult> for ExecResult {
fn from(result: codex_shell_escalation::ExecResult) -> Self {
impl From<ShellExecResult> for ExecResult {
fn from(result: ShellExecResult) -> Self {
Self {
exit_code: result.exit_code,
output: result.output,
@@ -68,7 +78,7 @@ pub struct ExecTool {
execve_wrapper: PathBuf,
policy: Arc<RwLock<Policy>>,
preserve_program_paths: bool,
sandbox_state: Arc<RwLock<Option<SandboxState>>>,
sandbox_state: Arc<RwLock<Option<CoreSandboxState>>>,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, rmcp::schemars::JsonSchema)]
@@ -83,7 +93,7 @@ pub struct ExecParams {
pub login: Option<bool>,
}
impl From<ExecParams> for codex_shell_escalation::ExecParams {
impl From<ExecParams> for ShellExecParams {
fn from(inner: ExecParams) -> Self {
Self {
command: inner.command,
@@ -99,14 +109,51 @@ struct McpEscalationPolicyFactory {
preserve_program_paths: bool,
}
struct McpShellCommandExecutor;
#[async_trait::async_trait]
impl ShellCommandExecutor for McpShellCommandExecutor {
async fn run(
&self,
command: Vec<String>,
cwd: PathBuf,
env: HashMap<String, String>,
cancel_rx: CancellationToken,
sandbox_state: &ShellEscalationSandboxState,
) -> anyhow::Result<ShellExecResult> {
let result = process_exec_tool_call(
codex_core::exec::ExecParams {
command,
cwd,
expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx),
env,
network: None,
sandbox_permissions: ProtocolSandboxPermissions::UseDefault,
windows_sandbox_level: WindowsSandboxLevel::Disabled,
justification: None,
arg0: None,
},
&sandbox_state.sandbox_policy,
&sandbox_state.sandbox_cwd,
&sandbox_state.codex_linux_sandbox_exe,
sandbox_state.use_linux_sandbox_bwrap,
None,
)
.await?;
Ok(ShellExecResult {
exit_code: result.exit_code,
output: result.aggregated_output.text,
duration: result.duration,
timed_out: result.timed_out,
})
}
}
impl EscalationPolicyFactory for McpEscalationPolicyFactory {
type Policy = McpEscalationPolicy;
fn create_policy(
&self,
policy: Arc<RwLock<Policy>>,
stopwatch: codex_shell_escalation::Stopwatch,
) -> Self::Policy {
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
McpEscalationPolicy::new(
policy,
self.context.clone(),
@@ -151,15 +198,21 @@ impl ExecTool {
.read()
.await
.clone()
.unwrap_or_else(|| SandboxState {
.unwrap_or_else(|| CoreSandboxState {
sandbox_policy: SandboxPolicy::new_read_only_policy(),
codex_linux_sandbox_exe: None,
sandbox_cwd: PathBuf::from(&params.workdir),
use_linux_sandbox_bwrap: false,
});
let shell_sandbox_state = ShellEscalationSandboxState {
sandbox_policy: sandbox_state.sandbox_policy.clone(),
codex_linux_sandbox_exe: sandbox_state.codex_linux_sandbox_exe.clone(),
sandbox_cwd: sandbox_state.sandbox_cwd.clone(),
use_linux_sandbox_bwrap: sandbox_state.use_linux_sandbox_bwrap,
};
let result = run_escalate_server(
params.into(),
&sandbox_state,
&shell_sandbox_state,
&self.bash_path,
&self.execve_wrapper,
self.policy.clone(),
@@ -168,6 +221,7 @@ impl ExecTool {
preserve_program_paths: self.preserve_program_paths,
},
effective_timeout,
&McpShellCommandExecutor,
)
.await
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
@@ -236,7 +290,7 @@ impl ServerHandler for ExecTool {
));
};
let Ok(sandbox_state) = serde_json::from_value::<SandboxState>(params.clone()) else {
let Ok(sandbox_state) = serde_json::from_value::<CoreSandboxState>(params.clone()) else {
return Err(McpError::invalid_params(
"failed to deserialize sandbox state".to_string(),
Some(params),

View File

@@ -2,9 +2,9 @@ use std::path::Path;
use codex_core::sandboxing::SandboxPermissions;
use codex_execpolicy::Policy;
use codex_shell_escalation::EscalateAction;
use codex_shell_escalation::EscalationPolicy;
use codex_shell_escalation::Stopwatch;
use codex_shell_escalation::unix::escalate_protocol::EscalateAction;
use codex_shell_escalation::unix::escalation_policy::EscalationPolicy;
use codex_shell_escalation::unix::stopwatch::Stopwatch;
use rmcp::ErrorData as McpError;
use rmcp::RoleServer;
use rmcp::model::CreateElicitationRequestParams;

View File

@@ -7,20 +7,21 @@ license.workspace = true
[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
codex-core = { workspace = true }
codex-execpolicy = { workspace = true }
codex-protocol = { workspace = true }
libc = { workspace = true }
serde_json = { workspace = true }
path-absolutize = { workspace = true }
serde = { workspace = true, features = ["derive"] }
socket2 = { workspace = true }
socket2 = { workspace = true, features = ["all"] }
tokio = { workspace = true, features = [
"io-std",
"net",
"macros",
"process",
"rt-multi-thread",
"signal",
"time",
] }
tokio-util = { workspace = true }
tracing = { workspace = true }

View File

@@ -1,3 +1,4 @@
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
#[cfg(unix)]
mod unix {
mod escalate_client;
@@ -19,3 +20,11 @@ mod unix {
#[cfg(unix)]
pub use unix::*;
||||||| base
=======
#[cfg(unix)]
pub mod unix;
#[cfg(unix)]
pub use unix::*;
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...

View File

@@ -1,3 +1,4 @@
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
use async_trait::async_trait;
use std::path::Path;
use std::sync::Arc;
@@ -69,3 +70,77 @@ impl EscalationPolicyFactory for ShellPolicyFactory {
}
}
}
||||||| base
=======
use async_trait::async_trait;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::escalate_protocol::EscalateAction;
use crate::escalation_policy::EscalationPolicy;
use crate::stopwatch::Stopwatch;
use crate::unix::escalate_server::EscalationPolicyFactory;
use codex_execpolicy::Policy;
/// Hook invoked per intercepted exec to decide what to do with the given
/// program, arguments, and working directory; the run's stopwatch is supplied
/// for time-aware decisions.
#[async_trait]
pub trait ShellActionProvider: Send + Sync {
    async fn determine_action(
        &self,
        file: &Path,
        argv: &[String],
        workdir: &Path,
        stopwatch: &Stopwatch,
    ) -> anyhow::Result<EscalateAction>;
}
/// Factory that turns a shared [`ShellActionProvider`] into per-run
/// escalation policies.
#[derive(Clone)]
pub struct ShellPolicyFactory {
    provider: Arc<dyn ShellActionProvider>,
}
impl ShellPolicyFactory {
    /// Build a factory that wraps `provider` in a shared [`Arc`].
    pub fn new<P>(provider: P) -> Self
    where
        P: ShellActionProvider + 'static,
    {
        Self::with_provider(Arc::new(provider))
    }

    /// Build a factory from an already-shared provider handle.
    pub fn with_provider(provider: Arc<dyn ShellActionProvider>) -> Self {
        Self { provider }
    }
}
/// Escalation policy for a single shell run: delegates every decision to the
/// shared provider, passing along this run's stopwatch.
pub struct ShellEscalationPolicy {
    provider: Arc<dyn ShellActionProvider>,
    stopwatch: Stopwatch,
}
#[async_trait]
impl EscalationPolicy for ShellEscalationPolicy {
    // Forward the decision to the provider along with this run's stopwatch.
    async fn determine_action(
        &self,
        file: &Path,
        argv: &[String],
        workdir: &Path,
    ) -> anyhow::Result<EscalateAction> {
        self.provider
            .determine_action(file, argv, workdir, &self.stopwatch)
            .await
    }
}
impl EscalationPolicyFactory for ShellPolicyFactory {
    type Policy = ShellEscalationPolicy;
    // Build a per-run policy sharing the provider; the execpolicy handle is
    // unused by this factory.
    fn create_policy(&self, _policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
        ShellEscalationPolicy {
            provider: Arc::clone(&self.provider),
            stopwatch,
        }
    }
}
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...

View File

@@ -1,3 +1,4 @@
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
use std::io;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd as _;
@@ -111,3 +112,113 @@ pub async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
}
}
}
||||||| base
=======
use std::io;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd as _;
use std::os::fd::OwnedFd;
use anyhow::Context as _;
use crate::unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
use crate::unix::escalate_protocol::EXEC_WRAPPER_ENV_VAR;
use crate::unix::escalate_protocol::EscalateAction;
use crate::unix::escalate_protocol::EscalateRequest;
use crate::unix::escalate_protocol::EscalateResponse;
use crate::unix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR;
use crate::unix::escalate_protocol::SuperExecMessage;
use crate::unix::escalate_protocol::SuperExecResult;
use crate::unix::socket::AsyncDatagramSocket;
use crate::unix::socket::AsyncSocket;
/// Recover the escalate datagram socket from the fd number passed down via
/// the env var; rejects negative (invalid) fd values.
fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
    // TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd.
    let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::<i32>()?;
    if client_fd < 0 {
        return Err(anyhow::anyhow!(
            "{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}"
        ));
    }
    // SAFETY: caller contract — the env var names an fd this process owns and
    // nothing else uses; from_raw_fd takes ownership of it.
    Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
}
/// Wrapper-side exec flow: handshake with the escalate server, send the
/// intercepted exec request, then act on the response — forward stdio fds for
/// a super-exec, replace this process via `execv`, or deny with exit code 1.
pub async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
    let handshake_client = get_escalate_client()?;
    let (server, client) = AsyncSocket::pair()?;
    // Single-byte datagram whose ancillary data carries the session's fd.
    const HANDSHAKE_MESSAGE: [u8; 1] = [0];
    handshake_client
        .send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
        .await
        .context("failed to send handshake datagram")?;
    // Forward the caller's env minus the escalation plumbing variables.
    let env = std::env::vars()
        .filter(|(k, _)| {
            !matches!(
                k.as_str(),
                ESCALATE_SOCKET_ENV_VAR | EXEC_WRAPPER_ENV_VAR | LEGACY_BASH_EXEC_WRAPPER_ENV_VAR
            )
        })
        .collect();
    client
        .send(EscalateRequest {
            file: file.clone().into(),
            argv: argv.clone(),
            workdir: std::env::current_dir()?,
            env,
        })
        .await
        .context("failed to send EscalateRequest")?;
    let message = client
        .receive::<EscalateResponse>()
        .await
        .context("failed to receive EscalateResponse")?;
    match message.action {
        EscalateAction::Escalate => {
            // TODO: maybe we should send ALL open FDs (except the escalate client)?
            // NOTE(review): OwnedFd::from_raw_fd takes ownership of the stdio
            // fds, so they are closed when this array drops after the send —
            // confirm that is intended here.
            let fds_to_send = [
                unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
                unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
                unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
            ];
            // TODO: also forward signals over the super-exec socket
            client
                .send_with_fds(
                    SuperExecMessage {
                        fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
                    },
                    &fds_to_send,
                )
                .await
                .context("failed to send SuperExecMessage")?;
            // The server runs the command on our behalf and reports its exit.
            let SuperExecResult { exit_code } = client.receive::<SuperExecResult>().await?;
            Ok(exit_code)
        }
        EscalateAction::Run => {
            // We avoid std::process::Command here because we want to be as transparent as
            // possible. std::os::unix::process::CommandExt has .exec() but it does some funky
            // stuff with signal masks and dup2() on its standard FDs, which we don't want.
            use std::ffi::CString;
            let file = CString::new(file).context("NUL in file")?;
            let argv_cstrs: Vec<CString> = argv
                .iter()
                .map(|s| CString::new(s.as_str()).context("NUL in argv"))
                .collect::<Result<Vec<_>, _>>()?;
            let mut argv: Vec<*const libc::c_char> =
                argv_cstrs.iter().map(|s| s.as_ptr()).collect();
            // execv requires a NULL-terminated argv array.
            argv.push(std::ptr::null());
            // SAFETY: file and every argv entry are valid NUL-terminated C
            // strings kept alive by the bindings above, and argv ends with
            // NULL; execv only returns on failure.
            let err = unsafe {
                libc::execv(file.as_ptr(), argv.as_ptr());
                std::io::Error::last_os_error()
            };
            Err(err.into())
        }
        EscalateAction::Deny { .. } => Ok(1),
    }
}
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...

View File

@@ -1,3 +1,4 @@
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
use std::collections::HashMap;
use std::os::fd::AsRawFd;
use std::path::Path;
@@ -375,3 +376,384 @@ mod tests {
server_task.await?
}
}
||||||| base
=======
use std::collections::HashMap;
use std::os::fd::AsRawFd;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use codex_execpolicy::Policy;
use codex_protocol::protocol::SandboxPolicy;
use path_absolutize::Absolutize as _;
use tokio::process::Command;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
use crate::unix::escalate_protocol::EXEC_WRAPPER_ENV_VAR;
use crate::unix::escalate_protocol::EscalateAction;
use crate::unix::escalate_protocol::EscalateRequest;
use crate::unix::escalate_protocol::EscalateResponse;
use crate::unix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR;
use crate::unix::escalate_protocol::SuperExecMessage;
use crate::unix::escalate_protocol::SuperExecResult;
use crate::unix::escalation_policy::EscalationPolicy;
use crate::unix::socket::AsyncDatagramSocket;
use crate::unix::socket::AsyncSocket;
use crate::unix::stopwatch::Stopwatch;
/// Sandbox configuration handed to the [`ShellCommandExecutor`] for a run.
#[derive(Debug, Clone)]
pub struct SandboxState {
    pub sandbox_policy: SandboxPolicy,
    // Helper binary used for Linux sandboxing, when available.
    pub codex_linux_sandbox_exe: Option<PathBuf>,
    pub sandbox_cwd: PathBuf,
    pub use_linux_sandbox_bwrap: bool,
}
/// Abstraction over how the escalate server actually runs the shell command,
/// letting hosts plug in their own (possibly sandboxed) executor. The
/// `cancel_rx` token aborts the run; `env` is the complete child environment.
#[async_trait::async_trait]
pub trait ShellCommandExecutor: Send + Sync {
    async fn run(
        &self,
        command: Vec<String>,
        cwd: PathBuf,
        env: HashMap<String, String>,
        cancel_rx: CancellationToken,
        sandbox_state: &SandboxState,
    ) -> anyhow::Result<ExecResult>;
}
/// Parameters for one escalate-server shell execution.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct ExecParams {
    /// The bash string to execute.
    pub command: String,
    /// The working directory to execute the command in. Must be an absolute path.
    pub workdir: String,
    /// The timeout for the command in milliseconds.
    pub timeout_ms: Option<u64>,
    /// Launch Bash with -lc instead of -c: defaults to true.
    pub login: Option<bool>,
}
/// Outcome of an escalate-server shell execution.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ExecResult {
    pub exit_code: i32,
    // Command output text — presumably aggregated stdout+stderr; confirm
    // against the executor implementations.
    pub output: String,
    pub duration: Duration,
    pub timed_out: bool,
}
/// Server that runs a shell command while answering escalation requests from
/// the execve wrapper injected into the command's environment.
#[allow(clippy::module_name_repetitions)]
pub struct EscalateServer {
    bash_path: PathBuf,
    execve_wrapper: PathBuf,
    policy: Arc<dyn EscalationPolicy>,
}
impl EscalateServer {
    /// Build a server from the shell path, wrapper path, and escalation policy.
    pub fn new<P>(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self
    where
        P: EscalationPolicy + Send + Sync + 'static,
    {
        Self {
            bash_path,
            execve_wrapper,
            policy: Arc::new(policy),
        }
    }

    /// Run `params.command` under bash with escalation plumbing: create a
    /// socketpair, inject the client end's fd number and the wrapper path into
    /// the child environment, and answer escalation handshakes on a background
    /// task until the command finishes.
    pub async fn exec(
        &self,
        params: ExecParams,
        cancel_rx: CancellationToken,
        sandbox_state: &SandboxState,
        command_executor: &dyn ShellCommandExecutor,
    ) -> anyhow::Result<ExecResult> {
        let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
        let client_socket = escalate_client.into_inner();
        let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone()));
        // Child env: the full parent environment plus escalation variables.
        let mut env = std::env::vars().collect::<HashMap<String, String>>();
        env.insert(
            ESCALATE_SOCKET_ENV_VAR.to_string(),
            client_socket.as_raw_fd().to_string(),
        );
        env.insert(
            EXEC_WRAPPER_ENV_VAR.to_string(),
            self.execve_wrapper.to_string_lossy().to_string(),
        );
        // Legacy variable carries the same wrapper path for older consumers.
        env.insert(
            LEGACY_BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
            self.execve_wrapper.to_string_lossy().to_string(),
        );
        // Use a login shell (-lc) unless the caller explicitly opted out.
        let command = vec![
            self.bash_path.to_string_lossy().to_string(),
            if params.login == Some(false) {
                "-c".to_string()
            } else {
                "-lc".to_string()
            },
            params.command,
        ];
        let result = command_executor
            .run(
                command,
                PathBuf::from(&params.workdir),
                env,
                cancel_rx,
                sandbox_state,
            )
            .await?;
        // Stop answering escalation requests once the command is done.
        escalate_task.abort();
        Ok(result)
    }
}
/// Factory for creating escalation policy instances for a single shell run.
pub trait EscalationPolicyFactory {
    /// Concrete policy type produced by this factory.
    type Policy: EscalationPolicy + Send + Sync + 'static;
    /// Builds a policy for one run from the shared execpolicy and the run's
    /// stopwatch.
    fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy;
}
/// Convenience entry point: builds a per-run [`EscalateServer`] from the
/// factory-produced policy and runs a single command with `effective_timeout`
/// enforced by a [`Stopwatch`].
#[allow(clippy::too_many_arguments)]
pub async fn run_escalate_server(
    exec_params: ExecParams,
    sandbox_state: &SandboxState,
    shell_program: impl AsRef<Path>,
    execve_wrapper: impl AsRef<Path>,
    policy: Arc<RwLock<Policy>>,
    escalation_policy_factory: impl EscalationPolicyFactory,
    effective_timeout: Duration,
    command_executor: &dyn ShellCommandExecutor,
) -> anyhow::Result<ExecResult> {
    // Grab the cancellation token before the stopwatch is consumed by the
    // policy factory.
    let timer = Stopwatch::new(effective_timeout);
    let cancellation = timer.cancellation_token();
    let run_policy = escalation_policy_factory.create_policy(policy, timer);
    EscalateServer::new(
        shell_program.as_ref().to_path_buf(),
        execve_wrapper.as_ref().to_path_buf(),
        run_policy,
    )
    .exec(exec_params, cancellation, sandbox_state, command_executor)
    .await
}
/// Accepts escalation handshakes on the per-run datagram socket.
///
/// Each datagram is expected to carry exactly one fd: one end of a fresh
/// stream socket for that escalation session. Every session is served on its
/// own task, so multiple execs can escalate concurrently.
async fn escalate_task(
    socket: AsyncDatagramSocket,
    policy: Arc<dyn EscalationPolicy>,
) -> anyhow::Result<()> {
    loop {
        let (_, mut fds) = socket.receive_with_fds().await?;
        // A malformed handshake is logged and skipped rather than killing the
        // accept loop.
        if fds.len() != 1 {
            tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len());
            continue;
        }
        let stream_socket = AsyncSocket::from_fd(fds.remove(0))?;
        let policy = policy.clone();
        tokio::spawn(async move {
            if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await {
                tracing::error!("escalate session failed: {err:?}");
            }
        });
    }
}
/// Serves one escalation session on `socket`.
///
/// Protocol: receive an `EscalateRequest`, consult `policy`, and reply with an
/// `EscalateResponse`. For `Escalate`, additionally receive a
/// `SuperExecMessage` (target fd numbers plus the fds themselves), spawn the
/// requested command in this process with those fds dup'd into place, and
/// report the child's exit code via `SuperExecResult`.
async fn handle_escalate_session_with_policy(
    socket: AsyncSocket,
    policy: Arc<dyn EscalationPolicy>,
) -> anyhow::Result<()> {
    let EscalateRequest {
        file,
        argv,
        workdir,
        env,
    } = socket.receive::<EscalateRequest>().await?;
    // Normalize paths before consulting the policy so relative paths cannot
    // dodge path-based rules.
    let file = PathBuf::from(&file).absolutize()?.into_owned();
    let workdir = PathBuf::from(&workdir).absolutize()?.into_owned();
    let action = policy
        .determine_action(file.as_path(), &argv, &workdir)
        .await
        .context("failed to determine escalation action")?;
    tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}");
    match action {
        EscalateAction::Run => {
            // The wrapper proceeds with the exec itself.
            socket
                .send(EscalateResponse {
                    action: EscalateAction::Run,
                })
                .await?;
        }
        EscalateAction::Escalate => {
            socket
                .send(EscalateResponse {
                    action: EscalateAction::Escalate,
                })
                .await?;
            let (msg, fds) = socket
                .receive_with_fds::<SuperExecMessage>()
                .await
                .context("failed to receive SuperExecMessage")?;
            // The fd numbers listed in the message must correspond 1:1 with
            // the fds attached to the control message.
            if fds.len() != msg.fds.len() {
                return Err(anyhow::anyhow!(
                    "mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message",
                    msg.fds.len(),
                    fds.len()
                ));
            }
            // A received fd that aliases one of the requested target numbers
            // would be clobbered mid-loop by dup2; reject that case for now.
            if msg
                .fds
                .iter()
                .any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
            {
                return Err(anyhow::anyhow!(
                    "overlapping fds not yet supported in SuperExecMessage"
                ));
            }
            // Guard against an empty argv: `argv[0]`/`argv[1..]` below would
            // panic on an out-of-bounds index otherwise.
            if argv.is_empty() {
                return Err(anyhow::anyhow!("SuperExec argv must not be empty"));
            }
            let mut command = Command::new(file);
            command
                .args(&argv[1..])
                .arg0(argv[0].clone())
                .envs(&env)
                .current_dir(&workdir)
                .stdin(Stdio::null())
                .stdout(Stdio::null())
                .stderr(Stdio::null());
            // SAFETY: the pre_exec closure runs between fork and exec and only
            // calls the async-signal-safe dup2.
            unsafe {
                command.pre_exec(move || {
                    for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) {
                        // dup2 failures were previously ignored; fail the
                        // spawn instead of exec'ing with missing fds.
                        if libc::dup2(src_fd.as_raw_fd(), *dst_fd) == -1 {
                            return Err(std::io::Error::last_os_error());
                        }
                    }
                    Ok(())
                });
            }
            let mut child = command.spawn()?;
            let exit_status = child.wait().await?;
            socket
                .send(SuperExecResult {
                    // 127 stands in when there is no exit code (e.g. the child
                    // was terminated by a signal).
                    exit_code: exit_status.code().unwrap_or(127),
                })
                .await?;
        }
        EscalateAction::Deny { reason } => {
            socket
                .send(EscalateResponse {
                    action: EscalateAction::Deny { reason },
                })
                .await?;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::path::Path;
    use std::path::PathBuf;
    // Policy stub that always returns the configured action.
    struct DeterministicEscalationPolicy {
        action: EscalateAction,
    }
    #[async_trait::async_trait]
    impl EscalationPolicy for DeterministicEscalationPolicy {
        async fn determine_action(
            &self,
            _file: &Path,
            _argv: &[String],
            _workdir: &Path,
        ) -> anyhow::Result<EscalateAction> {
            Ok(self.action.clone())
        }
    }
    #[tokio::test]
    async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
        let (server, client) = AsyncSocket::pair()?;
        let server_task = tokio::spawn(handle_escalate_session_with_policy(
            server,
            Arc::new(DeterministicEscalationPolicy {
                action: EscalateAction::Run,
            }),
        ));
        // ~10 KB of env vars, presumably to exercise multi-read framing on
        // the stream socket — TODO confirm intent.
        let mut env = HashMap::new();
        for i in 0..10 {
            let value = "A".repeat(1024);
            env.insert(format!("CODEX_TEST_VAR{i}"), value);
        }
        client
            .send(EscalateRequest {
                file: PathBuf::from("/bin/echo"),
                argv: vec!["echo".to_string()],
                workdir: PathBuf::from("/tmp"),
                env,
            })
            .await?;
        let response = client.receive::<EscalateResponse>().await?;
        assert_eq!(
            EscalateResponse {
                action: EscalateAction::Run,
            },
            response
        );
        server_task.await?
    }
    #[tokio::test]
    async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
        let (server, client) = AsyncSocket::pair()?;
        let server_task = tokio::spawn(handle_escalate_session_with_policy(
            server,
            Arc::new(DeterministicEscalationPolicy {
                action: EscalateAction::Escalate,
            }),
        ));
        client
            .send(EscalateRequest {
                file: PathBuf::from("/bin/sh"),
                argv: vec![
                    "sh".to_string(),
                    "-c".to_string(),
                    r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(),
                ],
                workdir: std::env::current_dir()?,
                env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]),
            })
            .await?;
        let response = client.receive::<EscalateResponse>().await?;
        assert_eq!(
            EscalateResponse {
                action: EscalateAction::Escalate,
            },
            response
        );
        // No extra fds to install for the escalated command.
        client
            .send_with_fds(SuperExecMessage { fds: Vec::new() }, &[])
            .await?;
        let result = client.receive::<SuperExecResult>().await?;
        // 42 proves the env var round-tripped into the escalated /bin/sh.
        assert_eq!(42, result.exit_code);
        server_task.await?
    }
}
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...

View File

@@ -1,3 +1,4 @@
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
pub mod escalate_client;
pub mod escalate_protocol;
pub mod escalate_server;
@@ -5,3 +6,13 @@ pub mod escalation_policy;
pub mod socket;
pub mod core_shell_escalation;
pub mod stopwatch;
||||||| base
=======
pub mod core_shell_escalation;
pub mod escalate_client;
pub mod escalate_protocol;
pub mod escalate_server;
pub mod escalation_policy;
pub mod socket;
pub mod stopwatch;
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...

View File

@@ -1,3 +1,4 @@
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
use libc::c_uint;
use serde::Deserialize;
use serde::Serialize;
@@ -505,3 +506,514 @@ mod tests {
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
}
}
||||||| base
=======
use libc::c_uint;
use serde::Deserialize;
use serde::Serialize;
use socket2::Domain;
use socket2::MaybeUninitSlice;
use socket2::MsgHdr;
use socket2::MsgHdrMut;
use socket2::Socket;
use socket2::Type;
use std::io::IoSlice;
use std::mem::MaybeUninit;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd;
use std::os::fd::OwnedFd;
use std::os::fd::RawFd;
use tokio::io::Interest;
use tokio::io::unix::AsyncFd;
// Upper bound on fds carried in a single SCM_RIGHTS control message.
const MAX_FDS_PER_MESSAGE: usize = 16;
// Stream frames begin with a little-endian u32 payload-length prefix.
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
// Largest datagram payload accepted by a single recvmsg call.
const MAX_DATAGRAM_SIZE: usize = 8192;
/// Reinterprets a `MaybeUninit<T>` slice as an initialized `&[T]`.
///
/// The caller guarantees that every element of `buf` has been initialized.
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and the caller
    // promises every element is initialized.
    unsafe { &*(buf as *const [MaybeUninit<T>] as *const [T]) }
}
/// Reinterprets a `&[MaybeUninit<T>; N]` array reference as `&[T; N]`.
///
/// The caller guarantees that every element of `buf` has been initialized.
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
    let ptr: *const [T; N] = (buf as *const [MaybeUninit<T>; N]).cast();
    // SAFETY: identical layout between `MaybeUninit<T>` and `T`; the caller
    // promises the array is fully initialized.
    unsafe { &*ptr }
}
/// Takes ownership of a `Vec<MaybeUninit<T>>` and reinterprets it as `Vec<T>`,
/// reusing the same allocation.
///
/// The caller guarantees that every element of `buf` has been initialized.
fn assume_init_vec<T>(buf: Vec<MaybeUninit<T>>) -> Vec<T> {
    // Prevent the original Vec from freeing the buffer; the rebuilt Vec takes
    // over ownership of the same allocation.
    let mut buf = std::mem::ManuallyDrop::new(buf);
    // SAFETY: pointer, length, and capacity come from a live Vec whose
    // elements the caller promises are initialized, and `MaybeUninit<T>` has
    // the same layout as `T`.
    unsafe { Vec::from_raw_parts(buf.as_mut_ptr().cast(), buf.len(), buf.capacity()) }
}
/// Returns the control-buffer size (CMSG_SPACE) needed to carry `count` fds in
/// one SCM_RIGHTS message.
fn control_space_for_fds(count: usize) -> usize {
    // SAFETY: CMSG_SPACE is a pure size computation with no side effects.
    unsafe { libc::CMSG_SPACE((count * size_of::<RawFd>()) as _) as usize }
}
/// Extracts the FDs from a SCM_RIGHTS control message.
///
/// Walks every cmsghdr in `control` and collects the file descriptors from
/// each SOL_SOCKET/SCM_RIGHTS entry, taking ownership of them as `OwnedFd`s.
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
    let mut fds = Vec::new();
    // Build a minimal msghdr pointing at the control buffer so libc's CMSG_*
    // macros can iterate it.
    let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
    hdr.msg_control = control.as_ptr() as *mut libc::c_void;
    hdr.msg_controllen = control.len() as _;
    let hdr = hdr; // drop mut
    let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
    while !cmsg.is_null() {
        let level = unsafe { (*cmsg).cmsg_level };
        let ty = unsafe { (*cmsg).cmsg_type };
        if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
            let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
            // Number of fds = (cmsg_len minus the header portion) / fd size.
            let fd_count: usize = {
                let cmsg_data_len =
                    unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize };
                cmsg_data_len / size_of::<RawFd>()
            };
            for i in 0..fd_count {
                let fd = unsafe { data_ptr.add(i).read() };
                // Wrap immediately so every received fd gets closed on drop.
                fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
            }
        }
        cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
    }
    fds
}
/// Read a frame from a SOCK_STREAM socket.
///
/// A frame is a message length prefix followed by a payload. FDs may arrive
/// in the control message that accompanies the frame header.
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
    let (len, fds) = read_frame_header(async_socket).await?;
    Ok((read_frame_payload(async_socket, len).await?, fds))
}
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
///
/// Returns the payload length decoded from the little-endian u32 prefix plus
/// any fds delivered alongside the first bytes of the header.
async fn read_frame_header(
    async_socket: &AsyncFd<Socket>,
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
    let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
    let mut filled = 0;
    let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
    let mut captured_control = false;
    while filled < LENGTH_PREFIX_SIZE {
        let mut guard = async_socket.readable().await?;
        // The first read should come with a control message containing any FDs.
        let read = if !captured_control {
            match guard.try_io(|inner| {
                let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
                let (read, control_len) = {
                    let mut msg = MsgHdrMut::new()
                        .with_buffers(&mut bufs)
                        .with_control(&mut control);
                    let read = inner.get_ref().recvmsg(&mut msg, 0)?;
                    (read, msg.control_len())
                };
                // Keep only the control bytes the kernel actually wrote.
                control.truncate(control_len);
                captured_control = true;
                Ok(read)
            }) {
                Ok(Ok(read)) => read,
                Ok(Err(err)) => return Err(err),
                // Spurious readiness: wait for the next readable event.
                Err(_would_block) => continue,
            }
        } else {
            // Any remaining header bytes after a short read carry no fds.
            match guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) {
                Ok(Ok(read)) => read,
                Ok(Err(err)) => return Err(err),
                Err(_would_block) => continue,
            }
        };
        if read == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "socket closed while receiving frame header",
            ));
        }
        filled += read;
        assert!(filled <= LENGTH_PREFIX_SIZE);
        if filled == LENGTH_PREFIX_SIZE {
            // All LENGTH_PREFIX_SIZE bytes of `header` are initialized now.
            let len_bytes = assume_init_slice(&header);
            let payload_len = u32::from_le_bytes(*len_bytes) as usize;
            let fds = extract_fds(assume_init(&control));
            return Ok((payload_len, fds));
        }
    }
    unreachable!("header loop always returns")
}
/// Read `message_len` bytes from a SOCK_STREAM socket.
///
/// Loops over readiness events until the full payload has arrived, failing
/// with `UnexpectedEof` if the peer closes mid-payload.
async fn read_frame_payload(
    async_socket: &AsyncFd<Socket>,
    message_len: usize,
) -> std::io::Result<Vec<u8>> {
    if message_len == 0 {
        return Ok(Vec::new());
    }
    let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
    let mut filled = 0;
    while filled < message_len {
        let mut guard = async_socket.readable().await?;
        let read = match guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])) {
            Ok(Ok(read)) => read,
            Ok(Err(err)) => return Err(err),
            // Spurious readiness: wait for the next readable event.
            Err(_would_block) => continue,
        };
        if read == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "socket closed while receiving frame payload",
            ));
        }
        filled += read;
        assert!(filled <= message_len);
        if filled == message_len {
            // Every byte up to `message_len` has now been written by recv.
            return Ok(assume_init_vec(payload));
        }
    }
    unreachable!("loop exits only after returning payload")
}
/// Sends `data` plus `fds` as a single datagram on `socket`.
///
/// A short write is reported as an error rather than retried, since a
/// datagram cannot be continued in a second sendmsg.
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
    let control = make_control_message(fds)?;
    let payload = [IoSlice::new(data)];
    // Only attach a control message when there are fds to pass.
    let msg = if control.is_empty() {
        MsgHdr::new().with_buffers(&payload)
    } else {
        MsgHdr::new().with_buffers(&payload).with_control(&control)
    };
    let written = socket.sendmsg(&msg, 0)?;
    if written != data.len() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            format!(
                "short datagram write: wrote {written} bytes out of {}",
                data.len()
            ),
        ));
    }
    Ok(())
}
/// Encodes `len` as the little-endian u32 length prefix used for stream
/// frames.
///
/// Fails with `InvalidInput` when `len` does not fit in a u32.
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
    match u32::try_from(len) {
        Ok(prefix) => Ok(prefix.to_le_bytes()),
        Err(_) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("message too large: {len}"),
        )),
    }
}
/// Builds a SCM_RIGHTS control message carrying `fds`.
///
/// Returns an empty buffer when there are no fds, and `InvalidInput` when the
/// count exceeds `MAX_FDS_PER_MESSAGE`.
fn make_control_message(fds: &[OwnedFd]) -> std::io::Result<Vec<u8>> {
    if fds.len() > MAX_FDS_PER_MESSAGE {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("too many fds: {}", fds.len()),
        ))
    } else if fds.is_empty() {
        Ok(Vec::new())
    } else {
        // CMSG_SPACE-sized, zeroed buffer; the cmsghdr and fd array are
        // written into it in place below.
        let mut control = vec![0u8; control_space_for_fds(fds.len())];
        unsafe {
            let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
            (*cmsg).cmsg_len =
                libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
            (*cmsg).cmsg_level = libc::SOL_SOCKET;
            (*cmsg).cmsg_type = libc::SCM_RIGHTS;
            let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
            for (i, fd) in fds.iter().enumerate() {
                data_ptr.add(i).write(fd.as_raw_fd());
            }
        }
        Ok(control)
    }
}
/// Receives one datagram (up to `MAX_DATAGRAM_SIZE` bytes) and any attached
/// fds from `socket`.
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
    let mut buffer = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
    let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
    let (read, control_len) = {
        let mut bufs = [MaybeUninitSlice::new(&mut buffer)];
        let mut msg = MsgHdrMut::new()
            .with_buffers(&mut bufs)
            .with_control(&mut control);
        let read = socket.recvmsg(&mut msg, 0)?;
        (read, msg.control_len())
    };
    // Only the first `read` / `control_len` bytes were initialized by recvmsg.
    let data = assume_init(&buffer[..read]).to_vec();
    let fds = extract_fds(assume_init(&control[..control_len]));
    Ok((data, fds))
}
/// Tokio-friendly wrapper around a non-blocking Unix SOCK_STREAM socket that
/// exchanges length-prefixed JSON frames, optionally with attached fds.
pub(crate) struct AsyncSocket {
    inner: AsyncFd<Socket>,
}
impl AsyncSocket {
    /// Wraps `socket`, switching it to non-blocking mode for use with AsyncFd.
    fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
        socket.set_nonblocking(true)?;
        let async_socket = AsyncFd::new(socket)?;
        Ok(AsyncSocket {
            inner: async_socket,
        })
    }
    /// Adopts an owned fd as an `AsyncSocket`.
    pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
        AsyncSocket::new(Socket::from(fd))
    }
    /// Creates a connected pair of Unix stream sockets.
    pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
        let (server, client) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
        Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
    }
    /// Serializes `msg` as JSON and sends it as one frame with `fds` attached.
    pub async fn send_with_fds<T: Serialize>(
        &self,
        msg: T,
        fds: &[OwnedFd],
    ) -> std::io::Result<()> {
        let payload = serde_json::to_vec(&msg)?;
        let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + payload.len());
        frame.extend_from_slice(&encode_length(payload.len())?);
        frame.extend_from_slice(&payload);
        send_stream_frame(&self.inner, &frame, fds).await
    }
    /// Receives one frame, deserializing the JSON payload and returning any
    /// fds that arrived with it.
    pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
        &self,
    ) -> std::io::Result<(T, Vec<OwnedFd>)> {
        let (payload, fds) = read_frame(&self.inner).await?;
        let message: T = serde_json::from_slice(&payload)?;
        Ok((message, fds))
    }
    /// Sends a message without any fds.
    pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
    where
        T: Serialize,
    {
        self.send_with_fds(&msg, &[]).await
    }
    /// Receives a message, warning about (and dropping) any unexpected fds.
    pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
        let (msg, fds) = self.receive_with_fds().await?;
        if !fds.is_empty() {
            tracing::warn!("unexpected fds in receive: {}", fds.len());
        }
        Ok(msg)
    }
    /// Unwraps back to the raw socket2 `Socket` (still non-blocking).
    pub fn into_inner(self) -> Socket {
        self.inner.into_inner()
    }
}
/// Writes `frame` to the stream socket, retrying on partial writes.
///
/// `fds` are attached only to the first successful write so they are
/// delivered exactly once, alongside the start of the frame.
async fn send_stream_frame(
    socket: &AsyncFd<Socket>,
    frame: &[u8],
    fds: &[OwnedFd],
) -> std::io::Result<()> {
    let mut written = 0;
    let mut include_fds = !fds.is_empty();
    while written < frame.len() {
        let mut guard = socket.writable().await?;
        let bytes_written = match guard
            .try_io(|inner| send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds))
        {
            Ok(Ok(bytes_written)) => bytes_written,
            Ok(Err(err)) => return Err(err),
            // Spurious writability: wait for the next writable event.
            Err(_would_block) => continue,
        };
        if bytes_written == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "socket closed while sending frame payload",
            ));
        }
        written += bytes_written;
        // fds were delivered with the first chunk; never resend them.
        include_fds = false;
    }
    Ok(())
}
/// Performs one sendmsg of as much of `frame` as the socket accepts,
/// attaching `fds` as SCM_RIGHTS only when `include_fds` is set.
fn send_stream_chunk(
    socket: &Socket,
    frame: &[u8],
    fds: &[OwnedFd],
    include_fds: bool,
) -> std::io::Result<usize> {
    let control = if include_fds {
        make_control_message(fds)?
    } else {
        Vec::new()
    };
    let payload = [IoSlice::new(frame)];
    // Only attach a control message when there are fds to pass.
    let msg = if control.is_empty() {
        MsgHdr::new().with_buffers(&payload)
    } else {
        MsgHdr::new().with_buffers(&payload).with_control(&control)
    };
    socket.sendmsg(&msg, 0)
}
/// Tokio-friendly wrapper around a non-blocking Unix SOCK_DGRAM socket.
pub(crate) struct AsyncDatagramSocket {
    inner: AsyncFd<Socket>,
}
impl AsyncDatagramSocket {
    /// Wraps `socket`, switching it to non-blocking mode.
    fn new(socket: Socket) -> std::io::Result<Self> {
        socket.set_nonblocking(true)?;
        Ok(Self {
            inner: AsyncFd::new(socket)?,
        })
    }
    /// Adopts a raw fd as a datagram socket.
    ///
    /// # Safety
    /// `fd` must be a valid, open Unix datagram socket that the caller owns;
    /// ownership of the fd transfers to the returned value.
    pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
        Self::new(unsafe { Socket::from_raw_fd(fd) })
    }
    /// Creates a connected pair of Unix datagram sockets.
    pub fn pair() -> std::io::Result<(Self, Self)> {
        let (server, client) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
        Ok((Self::new(server)?, Self::new(client)?))
    }
    /// Sends one datagram containing `data` with `fds` attached.
    pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
        self.inner
            .async_io(Interest::WRITABLE, |socket| {
                send_datagram_bytes(socket, data, fds)
            })
            .await
    }
    /// Receives one datagram and any attached fds.
    pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
        self.inner
            .async_io(Interest::READABLE, receive_datagram_bytes)
            .await
    }
    /// Unwraps back to the raw socket2 `Socket` (still non-blocking).
    pub fn into_inner(self) -> Socket {
        self.inner.into_inner()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde::Deserialize;
    use serde::Serialize;
    use std::os::fd::AsFd;
    use std::os::fd::AsRawFd;
    use tempfile::NamedTempFile;
    // Simple serde round-trip payload for the framing tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
    struct TestPayload {
        id: i32,
        label: String,
    }
    // Produces `count` owned fds, all duplicates of one temp file's fd.
    fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
        let file = NamedTempFile::new()?;
        let mut fds = Vec::new();
        for _ in 0..count {
            fds.push(file.as_fd().try_clone_to_owned()?);
        }
        Ok(fds)
    }
    #[tokio::test]
    async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
        let (server, client) = AsyncSocket::pair()?;
        let payload = TestPayload {
            id: 7,
            label: "round-trip".to_string(),
        };
        let send_fds = fd_list(1)?;
        let receive_task =
            tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
        client.send_with_fds(payload.clone(), &send_fds).await?;
        // Close the sender's copies so the received fd must stand on its own.
        drop(send_fds);
        let (received_payload, received_fds) = receive_task.await.unwrap()?;
        assert_eq!(payload, received_payload);
        assert_eq!(1, received_fds.len());
        // F_GETFD succeeds only on a valid open descriptor.
        let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
        assert!(
            fd_status >= 0,
            "expected received file descriptor to be valid, but got {fd_status}",
        );
        Ok(())
    }
    #[tokio::test]
    async fn async_socket_handles_large_payload() -> std::io::Result<()> {
        let (server, client) = AsyncSocket::pair()?;
        // 10 KB payload exercises multi-chunk stream sends/reads.
        let payload = vec![b'A'; 10_000];
        let receive_task = tokio::spawn(async move { server.receive::<Vec<u8>>().await });
        client.send(payload.clone()).await?;
        let received_payload = receive_task.await.unwrap()?;
        assert_eq!(payload, received_payload);
        Ok(())
    }
    #[tokio::test]
    async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
        let (server, client) = AsyncDatagramSocket::pair()?;
        let data = b"datagram payload".to_vec();
        let send_fds = fd_list(1)?;
        let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
        client.send_with_fds(&data, &send_fds).await?;
        drop(send_fds);
        let (received_bytes, received_fds) = receive_task.await.unwrap()?;
        assert_eq!(data, received_bytes);
        assert_eq!(1, received_fds.len());
        Ok(())
    }
    #[test]
    fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
        let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
        let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
        let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
        assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
        Ok(())
    }
    #[test]
    fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
        let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
        let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
        let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
        assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
        Ok(())
    }
    #[test]
    fn encode_length_errors_for_oversized_messages() {
        let err = encode_length(usize::MAX).unwrap_err();
        assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
    }
    #[tokio::test]
    async fn receive_fails_when_peer_closes_before_header() {
        let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
        // Closing the peer should surface as UnexpectedEof on the reader.
        drop(client);
        let err = server
            .receive::<serde_json::Value>()
            .await
            .expect_err("expected read failure");
        assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
    }
}
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...