mirror of
https://github.com/openai/codex.git
synced 2026-02-26 02:33:48 +00:00
Compare commits
2 Commits
dev/cc/new
...
pr12598
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62a42c31a7 | ||
|
|
86bfa68e42 |
26
codex-rs/Cargo.lock
generated
26
codex-rs/Cargo.lock
generated
@@ -1408,6 +1408,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"codex-apply-patch",
|
||||
"codex-linux-sandbox",
|
||||
"codex-shell-escalation",
|
||||
"codex-utils-home-dir",
|
||||
"dotenvy",
|
||||
"tempfile",
|
||||
@@ -1653,6 +1654,7 @@ dependencies = [
|
||||
"codex-rmcp-client",
|
||||
"codex-secrets",
|
||||
"codex-shell-command",
|
||||
"codex-shell-escalation",
|
||||
"codex-skills",
|
||||
"codex-state",
|
||||
"codex-utils-absolute-path",
|
||||
@@ -2199,6 +2201,7 @@ dependencies = [
|
||||
"which",
|
||||
]
|
||||
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
[[package]]
|
||||
name = "codex-shell-escalation"
|
||||
version = "0.0.0"
|
||||
@@ -2220,6 +2223,29 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
||||||| base
|
||||
=======
|
||||
[[package]]
|
||||
name = "codex-shell-escalation"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"codex-execpolicy",
|
||||
"codex-protocol",
|
||||
"libc",
|
||||
"path-absolutize",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket2 0.6.2",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
[[package]]
|
||||
name = "codex-skills"
|
||||
version = "0.0.0"
|
||||
|
||||
@@ -24,11 +24,6 @@ struct AppServerArgs {
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
// Run wrapper mode only after arg0 dispatch so `codex-linux-sandbox`
|
||||
// invocations don't get misclassified as zsh exec-wrapper calls.
|
||||
if codex_core::maybe_run_zsh_exec_wrapper_mode()? {
|
||||
return Ok(());
|
||||
}
|
||||
let args = AppServerArgs::parse();
|
||||
let managed_config_path = managed_config_path_from_debug_env();
|
||||
let loader_overrides = LoaderOverrides {
|
||||
|
||||
@@ -19,3 +19,6 @@ codex-utils-home-dir = { workspace = true }
|
||||
dotenvy = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread"] }
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
codex-shell-escalation = { workspace = true }
|
||||
|
||||
@@ -12,6 +12,8 @@ use tempfile::TempDir;
|
||||
const LINUX_SANDBOX_ARG0: &str = "codex-linux-sandbox";
|
||||
const APPLY_PATCH_ARG0: &str = "apply_patch";
|
||||
const MISSPELLED_APPLY_PATCH_ARG0: &str = "applypatch";
|
||||
#[cfg(unix)]
|
||||
const EXECVE_WRAPPER_ARG0: &str = "codex-execve-wrapper";
|
||||
const LOCK_FILENAME: &str = ".lock";
|
||||
const TOKIO_WORKER_STACK_SIZE_BYTES: usize = 16 * 1024 * 1024;
|
||||
|
||||
@@ -39,6 +41,32 @@ pub fn arg0_dispatch() -> Option<Arg0PathEntryGuard> {
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
#[cfg(unix)]
|
||||
if exe_name == EXECVE_WRAPPER_ARG0 {
|
||||
let mut args = std::env::args();
|
||||
let _ = args.next();
|
||||
let file = match args.next() {
|
||||
Some(file) => file,
|
||||
None => std::process::exit(1),
|
||||
};
|
||||
let argv = args.collect::<Vec<_>>();
|
||||
|
||||
let runtime = match tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
{
|
||||
Ok(runtime) => runtime,
|
||||
Err(_) => std::process::exit(1),
|
||||
};
|
||||
let exit_code = runtime.block_on(codex_shell_escalation::unix::escalate_client::run(
|
||||
file, argv,
|
||||
));
|
||||
match exit_code {
|
||||
Ok(exit_code) => std::process::exit(exit_code),
|
||||
Err(_) => std::process::exit(1),
|
||||
}
|
||||
}
|
||||
|
||||
if exe_name == LINUX_SANDBOX_ARG0 {
|
||||
// Safety: [`run_main`] never returns.
|
||||
codex_linux_sandbox::run_main();
|
||||
@@ -227,6 +255,8 @@ pub fn prepend_path_entry_for_codex_aliases() -> std::io::Result<Arg0PathEntryGu
|
||||
MISSPELLED_APPLY_PATCH_ARG0,
|
||||
#[cfg(target_os = "linux")]
|
||||
LINUX_SANDBOX_ARG0,
|
||||
#[cfg(unix)]
|
||||
EXECVE_WRAPPER_ARG0,
|
||||
] {
|
||||
let exe = std::env::current_exe()?;
|
||||
|
||||
|
||||
@@ -544,11 +544,6 @@ fn stage_str(stage: codex_core::features::Stage) -> &'static str {
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
// Run wrapper mode only after arg0 dispatch so `codex-linux-sandbox`
|
||||
// invocations don't get misclassified as zsh exec-wrapper calls.
|
||||
if codex_core::maybe_run_zsh_exec_wrapper_mode()? {
|
||||
return Ok(());
|
||||
}
|
||||
cli_main(codex_linux_sandbox_exe).await?;
|
||||
Ok(())
|
||||
})
|
||||
|
||||
@@ -138,6 +138,9 @@ windows-sys = { version = "0.52", features = [
|
||||
[target.'cfg(any(target_os = "freebsd", target_os = "openbsd"))'.dependencies]
|
||||
keyring = { workspace = true, features = ["sync-secret-service"] }
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
codex-shell-escalation = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
||||
assert_matches = { workspace = true }
|
||||
|
||||
@@ -249,7 +249,6 @@ use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use crate::util::backoff;
|
||||
use crate::windows_sandbox::WindowsSandboxLevelExt;
|
||||
use crate::zsh_exec_bridge::ZshExecBridge;
|
||||
use codex_async_utils::OrCancelExt;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_otel::TelemetryAuthMode;
|
||||
@@ -1208,7 +1207,8 @@ impl Session {
|
||||
"zsh fork feature enabled, but `zsh_path` is not configured; set `zsh_path` in config.toml"
|
||||
)
|
||||
})?;
|
||||
shell::get_shell(shell::ShellType::Zsh, Some(zsh_path)).ok_or_else(|| {
|
||||
let zsh_path = zsh_path.to_path_buf();
|
||||
shell::get_shell(shell::ShellType::Zsh, Some(&zsh_path)).ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"zsh fork feature enabled, but zsh_path `{}` is not usable; set `zsh_path` to a valid zsh executable",
|
||||
zsh_path.display()
|
||||
@@ -1287,12 +1287,6 @@ impl Session {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let zsh_exec_bridge =
|
||||
ZshExecBridge::new(config.zsh_path.clone(), config.codex_home.clone());
|
||||
zsh_exec_bridge
|
||||
.initialize_for_session(&conversation_id.to_string())
|
||||
.await;
|
||||
|
||||
let services = SessionServices {
|
||||
// Initialize the MCP connection manager with an uninitialized
|
||||
// instance. It will be replaced with one created via
|
||||
@@ -1308,7 +1302,7 @@ impl Session {
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
),
|
||||
zsh_exec_bridge,
|
||||
shell_zsh_path: config.zsh_path.clone(),
|
||||
analytics_events_client: AnalyticsEventsClient::new(
|
||||
Arc::clone(&config),
|
||||
Arc::clone(&auth_manager),
|
||||
@@ -4227,7 +4221,6 @@ mod handlers {
|
||||
.unified_exec_manager
|
||||
.terminate_all_processes()
|
||||
.await;
|
||||
sess.services.zsh_exec_bridge.shutdown().await;
|
||||
info!("Shutting down Codex instance");
|
||||
let history = sess.clone_history().await;
|
||||
let turn_count = history
|
||||
@@ -7895,7 +7888,7 @@ mod tests {
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
),
|
||||
zsh_exec_bridge: ZshExecBridge::default(),
|
||||
shell_zsh_path: None,
|
||||
analytics_events_client: AnalyticsEventsClient::new(
|
||||
Arc::clone(&config),
|
||||
Arc::clone(&auth_manager),
|
||||
@@ -8048,7 +8041,7 @@ mod tests {
|
||||
unified_exec_manager: UnifiedExecProcessManager::new(
|
||||
config.background_terminal_max_timeout,
|
||||
),
|
||||
zsh_exec_bridge: ZshExecBridge::default(),
|
||||
shell_zsh_path: None,
|
||||
analytics_events_client: AnalyticsEventsClient::new(
|
||||
Arc::clone(&config),
|
||||
Arc::clone(&auth_manager),
|
||||
|
||||
@@ -372,7 +372,7 @@ pub struct Config {
|
||||
pub js_repl_node_module_dirs: Vec<PathBuf>,
|
||||
|
||||
/// Optional absolute path to patched zsh used by zsh-exec-bridge-backed shell execution.
|
||||
pub zsh_path: Option<PathBuf>,
|
||||
pub zsh_path: Option<AbsolutePathBuf>,
|
||||
|
||||
/// Value to use for `reasoning.effort` when making a request using the
|
||||
/// Responses API.
|
||||
@@ -1484,7 +1484,7 @@ pub struct ConfigOverrides {
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub js_repl_node_path: Option<PathBuf>,
|
||||
pub js_repl_node_module_dirs: Option<Vec<PathBuf>>,
|
||||
pub zsh_path: Option<PathBuf>,
|
||||
pub zsh_path: Option<AbsolutePathBuf>,
|
||||
pub base_instructions: Option<String>,
|
||||
pub developer_instructions: Option<String>,
|
||||
pub personality: Option<Personality>,
|
||||
@@ -1905,8 +1905,8 @@ impl Config {
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let zsh_path = zsh_path_override
|
||||
.or(config_profile.zsh_path.map(Into::into))
|
||||
.or(cfg.zsh_path.map(Into::into));
|
||||
.or(config_profile.zsh_path)
|
||||
.or(cfg.zsh_path);
|
||||
|
||||
let review_model = override_review_model.or(cfg.review_model);
|
||||
|
||||
|
||||
@@ -109,7 +109,6 @@ pub mod terminal;
|
||||
mod tools;
|
||||
pub mod turn_diff_tracker;
|
||||
mod turn_metadata;
|
||||
mod zsh_exec_bridge;
|
||||
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
|
||||
pub use rollout::INTERACTIVE_SESSION_SOURCES;
|
||||
pub use rollout::RolloutRecorder;
|
||||
@@ -144,7 +143,17 @@ pub(crate) use codex_shell_command::is_safe_command;
|
||||
pub(crate) use codex_shell_command::parse_command;
|
||||
pub(crate) use codex_shell_command::powershell;
|
||||
|
||||
pub use client::ModelClient;
|
||||
pub use client::ModelClientSession;
|
||||
pub use client::ResponsesWebsocketVersion;
|
||||
pub use client::X_CODEX_TURN_METADATA_HEADER;
|
||||
pub use client::ws_version_from_features;
|
||||
pub use client_common::Prompt;
|
||||
pub use client_common::REVIEW_PROMPT;
|
||||
pub use client_common::ResponseEvent;
|
||||
pub use client_common::ResponseStream;
|
||||
pub use compact::content_items_to_text;
|
||||
pub use event_mapping::parse_turn_item;
|
||||
pub use exec_policy::ExecPolicyError;
|
||||
pub use exec_policy::check_execpolicy_for_warnings;
|
||||
pub use exec_policy::format_exec_policy_error_with_source;
|
||||
@@ -153,18 +162,6 @@ pub use file_watcher::FileWatcherEvent;
|
||||
pub use safety::get_platform_sandbox;
|
||||
pub use tools::spec::parse_tool_input_schema;
|
||||
pub use turn_metadata::build_turn_metadata_header;
|
||||
pub use zsh_exec_bridge::maybe_run_zsh_exec_wrapper_mode;
|
||||
|
||||
pub use client::ModelClient;
|
||||
pub use client::ModelClientSession;
|
||||
pub use client::ResponsesWebsocketVersion;
|
||||
pub use client::ws_version_from_features;
|
||||
pub use client_common::Prompt;
|
||||
pub use client_common::REVIEW_PROMPT;
|
||||
pub use client_common::ResponseEvent;
|
||||
pub use client_common::ResponseStream;
|
||||
pub use compact::content_items_to_text;
|
||||
pub use event_mapping::parse_turn_item;
|
||||
pub mod compact;
|
||||
pub mod memory_trace;
|
||||
pub mod otel_init;
|
||||
|
||||
@@ -163,19 +163,13 @@ impl SandboxManager {
|
||||
SandboxType::MacosSeatbelt => {
|
||||
let mut seatbelt_env = HashMap::new();
|
||||
seatbelt_env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
|
||||
let zsh_exec_bridge_wrapper_socket = env
|
||||
.get(crate::zsh_exec_bridge::ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR)
|
||||
.map(PathBuf::from);
|
||||
let zsh_exec_bridge_allowed_unix_sockets = zsh_exec_bridge_wrapper_socket
|
||||
.as_ref()
|
||||
.map_or_else(Vec::new, |path| vec![path.clone()]);
|
||||
let mut args = create_seatbelt_command_args(
|
||||
command.clone(),
|
||||
policy,
|
||||
sandbox_policy_cwd,
|
||||
enforce_managed_network,
|
||||
network,
|
||||
&zsh_exec_bridge_allowed_unix_sockets,
|
||||
&[],
|
||||
);
|
||||
let mut full_command = Vec::with_capacity(1 + args.len());
|
||||
full_command.push(MACOS_PATH_TO_SEATBELT_EXECUTABLE.to_string());
|
||||
|
||||
@@ -15,9 +15,9 @@ use crate::state_db::StateDbHandle;
|
||||
use crate::tools::network_approval::NetworkApprovalService;
|
||||
use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use crate::zsh_exec_bridge::ZshExecBridge;
|
||||
use codex_hooks::Hooks;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::watch;
|
||||
@@ -27,7 +27,7 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
|
||||
pub(crate) mcp_startup_cancellation_token: Mutex<CancellationToken>,
|
||||
pub(crate) unified_exec_manager: UnifiedExecProcessManager,
|
||||
pub(crate) zsh_exec_bridge: ZshExecBridge,
|
||||
pub(crate) shell_zsh_path: Option<AbsolutePathBuf>,
|
||||
pub(crate) analytics_events_client: AnalyticsEventsClient,
|
||||
pub(crate) hooks: Hooks,
|
||||
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
|
||||
|
||||
@@ -31,6 +31,7 @@ use async_trait::async_trait;
|
||||
use codex_apply_patch::ApplyPatchAction;
|
||||
use codex_apply_patch::ApplyPatchFileChange;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ApplyPatchHandler;
|
||||
|
||||
@@ -139,8 +140,8 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
let mut orchestrator = ToolOrchestrator::new();
|
||||
let mut runtime = ApplyPatchRuntime::new();
|
||||
let tool_ctx = ToolCtx {
|
||||
session: session.as_ref(),
|
||||
turn: turn.as_ref(),
|
||||
session: session.clone(),
|
||||
turn: turn.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tool_name: tool_name.to_string(),
|
||||
};
|
||||
@@ -149,7 +150,7 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
&mut runtime,
|
||||
&req,
|
||||
&tool_ctx,
|
||||
&turn,
|
||||
turn.as_ref(),
|
||||
turn.approval_policy.value(),
|
||||
)
|
||||
.await
|
||||
@@ -193,8 +194,8 @@ pub(crate) async fn intercept_apply_patch(
|
||||
command: &[String],
|
||||
cwd: &Path,
|
||||
timeout_ms: Option<u64>,
|
||||
session: &Session,
|
||||
turn: &TurnContext,
|
||||
session: Arc<Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
tracker: Option<&SharedTurnDiffTracker>,
|
||||
call_id: &str,
|
||||
tool_name: &str,
|
||||
@@ -203,11 +204,13 @@ pub(crate) async fn intercept_apply_patch(
|
||||
codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
|
||||
session
|
||||
.record_model_warning(
|
||||
format!("apply_patch was requested via {tool_name}. Use the apply_patch tool instead of exec_command."),
|
||||
turn,
|
||||
format!(
|
||||
"apply_patch was requested via {tool_name}. Use the apply_patch tool instead of exec_command."
|
||||
),
|
||||
turn.as_ref(),
|
||||
)
|
||||
.await;
|
||||
match apply_patch::apply_patch(turn, changes).await {
|
||||
match apply_patch::apply_patch(turn.as_ref(), changes).await {
|
||||
InternalApplyPatchInvocation::Output(item) => {
|
||||
let content = item?;
|
||||
Ok(Some(ToolOutput::Function {
|
||||
@@ -219,8 +222,12 @@ pub(crate) async fn intercept_apply_patch(
|
||||
let changes = convert_apply_patch_to_protocol(&apply.action);
|
||||
let approval_keys = file_paths_for_action(&apply.action);
|
||||
let emitter = ToolEmitter::apply_patch(changes.clone(), apply.auto_approved);
|
||||
let event_ctx =
|
||||
ToolEventCtx::new(session, turn, call_id, tracker.as_ref().copied());
|
||||
let event_ctx = ToolEventCtx::new(
|
||||
session.as_ref(),
|
||||
turn.as_ref(),
|
||||
call_id,
|
||||
tracker.as_ref().copied(),
|
||||
);
|
||||
emitter.begin(event_ctx).await;
|
||||
|
||||
let req = ApplyPatchRequest {
|
||||
@@ -235,8 +242,8 @@ pub(crate) async fn intercept_apply_patch(
|
||||
let mut orchestrator = ToolOrchestrator::new();
|
||||
let mut runtime = ApplyPatchRuntime::new();
|
||||
let tool_ctx = ToolCtx {
|
||||
session,
|
||||
turn,
|
||||
session: session.clone(),
|
||||
turn: turn.clone(),
|
||||
call_id: call_id.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
};
|
||||
@@ -245,13 +252,17 @@ pub(crate) async fn intercept_apply_patch(
|
||||
&mut runtime,
|
||||
&req,
|
||||
&tool_ctx,
|
||||
turn,
|
||||
turn.as_ref(),
|
||||
turn.approval_policy.value(),
|
||||
)
|
||||
.await
|
||||
.map(|result| result.output);
|
||||
let event_ctx =
|
||||
ToolEventCtx::new(session, turn, call_id, tracker.as_ref().copied());
|
||||
let event_ctx = ToolEventCtx::new(
|
||||
session.as_ref(),
|
||||
turn.as_ref(),
|
||||
call_id,
|
||||
tracker.as_ref().copied(),
|
||||
);
|
||||
let content = emitter.finish(event_ctx, out).await?;
|
||||
Ok(Some(ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::Text(content),
|
||||
|
||||
@@ -296,8 +296,8 @@ impl ShellHandler {
|
||||
&exec_params.command,
|
||||
&exec_params.cwd,
|
||||
exec_params.expiration.timeout_ms(),
|
||||
session.as_ref(),
|
||||
turn.as_ref(),
|
||||
session.clone(),
|
||||
turn.clone(),
|
||||
Some(&tracker),
|
||||
&call_id,
|
||||
tool_name.as_str(),
|
||||
@@ -343,8 +343,8 @@ impl ShellHandler {
|
||||
let mut orchestrator = ToolOrchestrator::new();
|
||||
let mut runtime = ShellRuntime::new();
|
||||
let tool_ctx = ToolCtx {
|
||||
session: session.as_ref(),
|
||||
turn: turn.as_ref(),
|
||||
session: session.clone(),
|
||||
turn: turn.clone(),
|
||||
call_id: call_id.clone(),
|
||||
tool_name,
|
||||
};
|
||||
|
||||
@@ -172,8 +172,8 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
&command,
|
||||
&cwd,
|
||||
Some(yield_time_ms),
|
||||
context.session.as_ref(),
|
||||
context.turn.as_ref(),
|
||||
context.session.clone(),
|
||||
context.turn.clone(),
|
||||
Some(&tracker),
|
||||
&context.call_id,
|
||||
tool_name.as_str(),
|
||||
|
||||
@@ -48,7 +48,7 @@ impl ToolOrchestrator {
|
||||
async fn run_attempt<Rq, Out, T>(
|
||||
tool: &mut T,
|
||||
req: &Rq,
|
||||
tool_ctx: &ToolCtx<'_>,
|
||||
tool_ctx: &ToolCtx,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
has_managed_network_requirements: bool,
|
||||
) -> (Result<Out, ToolError>, Option<DeferredNetworkApproval>)
|
||||
@@ -56,7 +56,7 @@ impl ToolOrchestrator {
|
||||
T: ToolRuntime<Rq, Out>,
|
||||
{
|
||||
let network_approval = begin_network_approval(
|
||||
tool_ctx.session,
|
||||
&tool_ctx.session,
|
||||
&tool_ctx.turn.sub_id,
|
||||
&tool_ctx.call_id,
|
||||
has_managed_network_requirements,
|
||||
@@ -65,8 +65,8 @@ impl ToolOrchestrator {
|
||||
.await;
|
||||
|
||||
let attempt_tool_ctx = ToolCtx {
|
||||
session: tool_ctx.session,
|
||||
turn: tool_ctx.turn,
|
||||
session: tool_ctx.session.clone(),
|
||||
turn: tool_ctx.turn.clone(),
|
||||
call_id: tool_ctx.call_id.clone(),
|
||||
tool_name: tool_ctx.tool_name.clone(),
|
||||
};
|
||||
@@ -79,7 +79,7 @@ impl ToolOrchestrator {
|
||||
match network_approval.mode() {
|
||||
NetworkApprovalMode::Immediate => {
|
||||
let finalize_result =
|
||||
finish_immediate_network_approval(tool_ctx.session, network_approval).await;
|
||||
finish_immediate_network_approval(&tool_ctx.session, network_approval).await;
|
||||
if let Err(err) = finalize_result {
|
||||
return (Err(err), None);
|
||||
}
|
||||
@@ -88,7 +88,7 @@ impl ToolOrchestrator {
|
||||
NetworkApprovalMode::Deferred => {
|
||||
let deferred = network_approval.into_deferred();
|
||||
if run_result.is_err() {
|
||||
finish_deferred_network_approval(tool_ctx.session, deferred).await;
|
||||
finish_deferred_network_approval(&tool_ctx.session, deferred).await;
|
||||
return (run_result, None);
|
||||
}
|
||||
(run_result, deferred)
|
||||
@@ -100,7 +100,7 @@ impl ToolOrchestrator {
|
||||
&mut self,
|
||||
tool: &mut T,
|
||||
req: &Rq,
|
||||
tool_ctx: &ToolCtx<'_>,
|
||||
tool_ctx: &ToolCtx,
|
||||
turn_ctx: &crate::codex::TurnContext,
|
||||
approval_policy: AskForApproval,
|
||||
) -> Result<OrchestratorRunResult<Out>, ToolError>
|
||||
@@ -128,7 +128,7 @@ impl ToolOrchestrator {
|
||||
}
|
||||
ExecApprovalRequirement::NeedsApproval { reason, .. } => {
|
||||
let approval_ctx = ApprovalCtx {
|
||||
session: tool_ctx.session,
|
||||
session: &tool_ctx.session,
|
||||
turn: turn_ctx,
|
||||
call_id: &tool_ctx.call_id,
|
||||
retry_reason: reason,
|
||||
@@ -256,7 +256,7 @@ impl ToolOrchestrator {
|
||||
&& network_approval_context.is_none();
|
||||
if !bypass_retry_approval {
|
||||
let approval_ctx = ApprovalCtx {
|
||||
session: tool_ctx.session,
|
||||
session: &tool_ctx.session,
|
||||
turn: turn_ctx,
|
||||
call_id: &tool_ctx.call_id,
|
||||
retry_reason: Some(retry_reason),
|
||||
|
||||
@@ -70,7 +70,7 @@ impl ApplyPatchRuntime {
|
||||
})
|
||||
}
|
||||
|
||||
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
||||
fn stdout_stream(ctx: &ToolCtx) -> Option<crate::exec::StdoutStream> {
|
||||
Some(crate::exec::StdoutStream {
|
||||
sub_id: ctx.turn.sub_id.clone(),
|
||||
call_id: ctx.call_id.clone(),
|
||||
@@ -156,7 +156,7 @@ impl ToolRuntime<ApplyPatchRequest, ExecToolCallOutput> for ApplyPatchRuntime {
|
||||
&mut self,
|
||||
req: &ApplyPatchRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx<'_>,
|
||||
ctx: &ToolCtx,
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
let spec = Self::build_command_spec(req)?;
|
||||
let env = attempt
|
||||
|
||||
@@ -5,7 +5,11 @@ Executes shell requests under the orchestrator: asks for approval when needed,
|
||||
builds a CommandSpec, and runs it under the current SandboxAttempt.
|
||||
*/
|
||||
use crate::command_canonicalization::canonicalize_command_for_approval;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::SandboxErr;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::exec::SandboxType;
|
||||
use crate::exec::is_likely_sandbox_denied;
|
||||
use crate::features::Feature;
|
||||
use crate::powershell::prefix_powershell_script_with_utf8;
|
||||
use crate::sandboxing::SandboxPermissions;
|
||||
@@ -26,19 +30,56 @@ use crate::tools::sandboxing::ToolCtx;
|
||||
use crate::tools::sandboxing::ToolError;
|
||||
use crate::tools::sandboxing::ToolRuntime;
|
||||
use crate::tools::sandboxing::with_cached_approval;
|
||||
use crate::zsh_exec_bridge::ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR;
|
||||
use codex_execpolicy::Decision;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_execpolicy::RuleMatch;
|
||||
use codex_network_proxy::NetworkProxy;
|
||||
use codex_protocol::config_types::WindowsSandboxLevel;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::ReviewDecision;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_shell_command::bash::parse_shell_lc_plain_commands;
|
||||
use codex_shell_command::bash::parse_shell_lc_single_command_prefix;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::core_shell_escalation::ShellActionProvider;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::core_shell_escalation::ShellPolicyFactory;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::escalate_protocol::EscalateAction;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::escalate_server::ExecParams;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::escalate_server::ExecResult;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::escalate_server::SandboxState;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::escalate_server::ShellCommandExecutor;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::escalate_server::run_escalate_server;
|
||||
#[cfg(unix)]
|
||||
use codex_shell_escalation::unix::stopwatch::Stopwatch;
|
||||
#[cfg(unix)]
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use futures::future::BoxFuture;
|
||||
use shlex::try_join as shlex_try_join;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
#[cfg(unix)]
|
||||
use std::sync::Arc;
|
||||
#[cfg(unix)]
|
||||
use std::time::Duration;
|
||||
#[cfg(unix)]
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ShellRequest {
|
||||
pub command: Vec<String>,
|
||||
pub cwd: PathBuf,
|
||||
pub timeout_ms: Option<u64>,
|
||||
pub env: std::collections::HashMap<String, String>,
|
||||
pub explicit_env_overrides: std::collections::HashMap<String, String>,
|
||||
pub env: HashMap<String, String>,
|
||||
pub explicit_env_overrides: HashMap<String, String>,
|
||||
pub network: Option<NetworkProxy>,
|
||||
pub sandbox_permissions: SandboxPermissions,
|
||||
pub justification: Option<String>,
|
||||
@@ -60,7 +101,7 @@ impl ShellRuntime {
|
||||
Self
|
||||
}
|
||||
|
||||
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
||||
fn stdout_stream(ctx: &ToolCtx) -> Option<crate::exec::StdoutStream> {
|
||||
Some(crate::exec::StdoutStream {
|
||||
sub_id: ctx.turn.sub_id.clone(),
|
||||
call_id: ctx.call_id.clone(),
|
||||
@@ -73,6 +114,7 @@ impl Sandboxable for ShellRuntime {
|
||||
fn sandbox_preference(&self) -> SandboxablePreference {
|
||||
SandboxablePreference::Auto
|
||||
}
|
||||
|
||||
fn escalate_on_failure(&self) -> bool {
|
||||
true
|
||||
}
|
||||
@@ -146,11 +188,214 @@ impl Approvable<ShellRequest> for ShellRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
struct CoreShellActionProvider {
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
session: std::sync::Arc<crate::codex::Session>,
|
||||
turn: std::sync::Arc<crate::codex::TurnContext>,
|
||||
call_id: String,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
sandbox_permissions: SandboxPermissions,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl CoreShellActionProvider {
|
||||
fn decision_driven_by_policy(matched_rules: &[RuleMatch], decision: Decision) -> bool {
|
||||
matched_rules.iter().any(|rule_match| {
|
||||
!matches!(rule_match, RuleMatch::HeuristicsRuleMatch { .. })
|
||||
&& rule_match.decision() == decision
|
||||
})
|
||||
}
|
||||
|
||||
async fn prompt(
|
||||
&self,
|
||||
command: &[String],
|
||||
workdir: &Path,
|
||||
stopwatch: &Stopwatch,
|
||||
) -> anyhow::Result<ReviewDecision> {
|
||||
let command = command.to_vec();
|
||||
let workdir = workdir.to_path_buf();
|
||||
let session = self.session.clone();
|
||||
let turn = self.turn.clone();
|
||||
let call_id = self.call_id.clone();
|
||||
Ok(stopwatch
|
||||
.pause_for(async move {
|
||||
session
|
||||
.request_command_approval(
|
||||
&turn, call_id, None, command, workdir, None, None, None,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[async_trait::async_trait]
|
||||
impl ShellActionProvider for CoreShellActionProvider {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
stopwatch: &Stopwatch,
|
||||
) -> anyhow::Result<EscalateAction> {
|
||||
let command = std::iter::once(file.to_string_lossy().to_string())
|
||||
.chain(argv.iter().cloned())
|
||||
.collect::<Vec<_>>();
|
||||
let (commands, used_complex_parsing) =
|
||||
if let Some(commands) = parse_shell_lc_plain_commands(&command) {
|
||||
(commands, false)
|
||||
} else if let Some(single_command) = parse_shell_lc_single_command_prefix(&command) {
|
||||
(vec![single_command], true)
|
||||
} else {
|
||||
(vec![command.clone()], false)
|
||||
};
|
||||
|
||||
let policy = self.policy.read().await;
|
||||
let fallback = |cmd: &[String]| {
|
||||
crate::exec_policy::render_decision_for_unmatched_command(
|
||||
self.approval_policy,
|
||||
&self.sandbox_policy,
|
||||
cmd,
|
||||
self.sandbox_permissions,
|
||||
used_complex_parsing,
|
||||
)
|
||||
};
|
||||
let evaluation = policy.check_multiple(commands.iter(), &fallback);
|
||||
let decision_driven_by_policy =
|
||||
Self::decision_driven_by_policy(&evaluation.matched_rules, evaluation.decision);
|
||||
let needs_escalation =
|
||||
self.sandbox_permissions.requires_escalated_permissions() || decision_driven_by_policy;
|
||||
|
||||
Ok(match evaluation.decision {
|
||||
Decision::Forbidden => EscalateAction::Deny {
|
||||
reason: Some("Execution forbidden by policy".to_string()),
|
||||
},
|
||||
Decision::Prompt => {
|
||||
if self.approval_policy == AskForApproval::Never {
|
||||
EscalateAction::Deny {
|
||||
reason: Some("Execution forbidden by policy".to_string()),
|
||||
}
|
||||
} else if decision_driven_by_policy {
|
||||
EscalateAction::Escalate
|
||||
} else {
|
||||
match self.prompt(&command, workdir, stopwatch).await? {
|
||||
ReviewDecision::Approved
|
||||
| ReviewDecision::ApprovedExecpolicyAmendment { .. }
|
||||
| ReviewDecision::ApprovedForSession => {
|
||||
if needs_escalation {
|
||||
EscalateAction::Escalate
|
||||
} else {
|
||||
EscalateAction::Run
|
||||
}
|
||||
}
|
||||
ReviewDecision::Denied => EscalateAction::Deny {
|
||||
reason: Some("User denied execution".to_string()),
|
||||
},
|
||||
ReviewDecision::Abort => EscalateAction::Deny {
|
||||
reason: Some("User cancelled execution".to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
Decision::Allow => EscalateAction::Run,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
struct CoreShellCommandExecutor;
|
||||
|
||||
#[cfg(unix)]
|
||||
#[async_trait::async_trait]
|
||||
impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
async fn run(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
env: HashMap<String, String>,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let result = crate::exec::process_exec_tool_call(
|
||||
crate::exec::ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
expiration: crate::exec::ExecExpiration::Cancellation(cancel_rx),
|
||||
env,
|
||||
network: None,
|
||||
sandbox_permissions: SandboxPermissions::UseDefault,
|
||||
windows_sandbox_level: WindowsSandboxLevel::Disabled,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
},
|
||||
&sandbox_state.sandbox_policy,
|
||||
&sandbox_state.sandbox_cwd,
|
||||
&sandbox_state.codex_linux_sandbox_exe,
|
||||
sandbox_state.use_linux_sandbox_bwrap,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn shell_execve_wrapper() -> anyhow::Result<PathBuf> {
|
||||
let exe = std::env::current_exe()?;
|
||||
exe.parent()
|
||||
.map(|parent| parent.join("codex-execve-wrapper"))
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to determine codex-execve-wrapper path"))
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn shell_exec_zsh_path(path: &AbsolutePathBuf) -> PathBuf {
|
||||
path.to_path_buf()
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn map_exec_result(
|
||||
sandbox: SandboxType,
|
||||
result: ExecResult,
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: result.exit_code,
|
||||
stdout: crate::exec::StreamOutput::new(result.output.clone()),
|
||||
stderr: crate::exec::StreamOutput::new(String::new()),
|
||||
aggregated_output: crate::exec::StreamOutput::new(result.output.clone()),
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
};
|
||||
|
||||
if result.timed_out {
|
||||
return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Timeout {
|
||||
output: Box::new(output),
|
||||
})));
|
||||
}
|
||||
|
||||
if is_likely_sandbox_denied(sandbox, &output) {
|
||||
return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
network_policy_decision: None,
|
||||
})));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
|
||||
fn network_approval_spec(
|
||||
&self,
|
||||
req: &ShellRequest,
|
||||
_ctx: &ToolCtx<'_>,
|
||||
_ctx: &ToolCtx,
|
||||
) -> Option<NetworkApprovalSpec> {
|
||||
req.network.as_ref()?;
|
||||
Some(NetworkApprovalSpec {
|
||||
@@ -163,17 +408,15 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
|
||||
&mut self,
|
||||
req: &ShellRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx<'_>,
|
||||
ctx: &ToolCtx,
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
let base_command = &req.command;
|
||||
let session_shell = ctx.session.user_shell();
|
||||
let command = maybe_wrap_shell_lc_with_snapshot(
|
||||
base_command,
|
||||
session_shell.as_ref(),
|
||||
&req.command,
|
||||
ctx.session.user_shell().as_ref(),
|
||||
&req.cwd,
|
||||
&req.explicit_env_overrides,
|
||||
);
|
||||
let command = if matches!(session_shell.shell_type, ShellType::PowerShell)
|
||||
let command = if matches!(ctx.session.user_shell().shell_type, ShellType::PowerShell)
|
||||
&& ctx.session.features().enabled(Feature::PowershellUtf8)
|
||||
{
|
||||
prefix_powershell_script_with_utf8(&command)
|
||||
@@ -181,21 +424,15 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
|
||||
command
|
||||
};
|
||||
|
||||
if ctx.session.features().enabled(Feature::ShellZshFork) {
|
||||
let wrapper_socket_path = ctx
|
||||
.session
|
||||
.services
|
||||
.zsh_exec_bridge
|
||||
.next_wrapper_socket_path();
|
||||
let mut zsh_fork_env = req.env.clone();
|
||||
zsh_fork_env.insert(
|
||||
ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR.to_string(),
|
||||
wrapper_socket_path.to_string_lossy().to_string(),
|
||||
);
|
||||
#[cfg(unix)]
|
||||
if let Some(shell_zsh_path) = ctx.session.services.shell_zsh_path.as_ref()
|
||||
&& ctx.session.features().enabled(Feature::ShellZshFork)
|
||||
&& matches!(ctx.session.user_shell().shell_type, ShellType::Zsh)
|
||||
{
|
||||
let spec = build_command_spec(
|
||||
&command,
|
||||
&req.cwd,
|
||||
&zsh_fork_env,
|
||||
&req.env,
|
||||
req.timeout_ms.into(),
|
||||
req.sandbox_permissions,
|
||||
req.justification.clone(),
|
||||
@@ -203,12 +440,52 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
|
||||
let env = attempt
|
||||
.env_for(spec, req.network.as_ref())
|
||||
.map_err(|err| ToolError::Codex(err.into()))?;
|
||||
return ctx
|
||||
.session
|
||||
.services
|
||||
.zsh_exec_bridge
|
||||
.execute_shell_request(&env, ctx.session, ctx.turn, &ctx.call_id)
|
||||
.await;
|
||||
let (_, args) = env
|
||||
.command
|
||||
.split_first()
|
||||
.ok_or_else(|| ToolError::Rejected("command args are empty".to_string()))?;
|
||||
let script = shlex_try_join(args.iter().map(String::as_str))
|
||||
.map_err(|err| ToolError::Rejected(format!("serialize shell script: {err}")))?;
|
||||
let effective_timeout = Duration::from_millis(
|
||||
req.timeout_ms
|
||||
.unwrap_or(crate::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
|
||||
);
|
||||
let exec_policy = Arc::new(RwLock::new(
|
||||
ctx.session.services.exec_policy.current().as_ref().clone(),
|
||||
));
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: ctx.turn.sandbox_policy.get().clone(),
|
||||
codex_linux_sandbox_exe: attempt.codex_linux_sandbox_exe.cloned(),
|
||||
sandbox_cwd: req.cwd.clone(),
|
||||
use_linux_sandbox_bwrap: attempt.use_linux_sandbox_bwrap,
|
||||
};
|
||||
let exec_result = run_escalate_server(
|
||||
ExecParams {
|
||||
command: script,
|
||||
workdir: req.cwd.to_string_lossy().to_string(),
|
||||
timeout_ms: Some(effective_timeout.as_millis() as u64),
|
||||
login: Some(false),
|
||||
},
|
||||
&sandbox_state,
|
||||
shell_exec_zsh_path(shell_zsh_path),
|
||||
shell_execve_wrapper().map_err(|err| ToolError::Rejected(format!("{err}")))?,
|
||||
exec_policy.clone(),
|
||||
ShellPolicyFactory::new(CoreShellActionProvider {
|
||||
policy: Arc::clone(&exec_policy),
|
||||
session: Arc::clone(&ctx.session),
|
||||
turn: Arc::clone(&ctx.turn),
|
||||
call_id: ctx.call_id.clone(),
|
||||
approval_policy: ctx.turn.approval_policy.value(),
|
||||
sandbox_policy: attempt.policy.clone(),
|
||||
sandbox_permissions: req.sandbox_permissions,
|
||||
}),
|
||||
effective_timeout,
|
||||
&CoreShellCommandExecutor,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| ToolError::Rejected(err.to_string()))?;
|
||||
|
||||
return map_exec_result(attempt.sandbox, exec_result);
|
||||
}
|
||||
|
||||
let spec = build_command_spec(
|
||||
|
||||
@@ -153,7 +153,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
|
||||
fn network_approval_spec(
|
||||
&self,
|
||||
req: &UnifiedExecRequest,
|
||||
_ctx: &ToolCtx<'_>,
|
||||
_ctx: &ToolCtx,
|
||||
) -> Option<NetworkApprovalSpec> {
|
||||
req.network.as_ref()?;
|
||||
Some(NetworkApprovalSpec {
|
||||
@@ -166,7 +166,7 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
|
||||
&mut self,
|
||||
req: &UnifiedExecRequest,
|
||||
attempt: &SandboxAttempt<'_>,
|
||||
ctx: &ToolCtx<'_>,
|
||||
ctx: &ToolCtx,
|
||||
) -> Result<UnifiedExecProcess, ToolError> {
|
||||
let base_command = &req.command;
|
||||
let session_shell = ctx.session.user_shell();
|
||||
|
||||
@@ -18,14 +18,14 @@ use codex_protocol::approvals::ExecPolicyAmendment;
|
||||
use codex_protocol::approvals::NetworkApprovalContext;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::ReviewDecision;
|
||||
use futures::Future;
|
||||
use futures::future::BoxFuture;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::hash::Hash;
|
||||
use std::path::Path;
|
||||
|
||||
use futures::Future;
|
||||
use futures::future::BoxFuture;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub(crate) struct ApprovalStore {
|
||||
@@ -267,9 +267,9 @@ pub(crate) trait Sandboxable {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ToolCtx<'a> {
|
||||
pub session: &'a Session,
|
||||
pub turn: &'a TurnContext,
|
||||
pub(crate) struct ToolCtx {
|
||||
pub session: Arc<Session>,
|
||||
pub turn: Arc<TurnContext>,
|
||||
pub call_id: String,
|
||||
pub tool_name: String,
|
||||
}
|
||||
@@ -281,7 +281,7 @@ pub(crate) enum ToolError {
|
||||
}
|
||||
|
||||
pub(crate) trait ToolRuntime<Req, Out>: Approvable<Req> + Sandboxable {
|
||||
fn network_approval_spec(&self, _req: &Req, _ctx: &ToolCtx<'_>) -> Option<NetworkApprovalSpec> {
|
||||
fn network_approval_spec(&self, _req: &Req, _ctx: &ToolCtx) -> Option<NetworkApprovalSpec> {
|
||||
None
|
||||
}
|
||||
|
||||
|
||||
@@ -594,8 +594,8 @@ impl UnifiedExecProcessManager {
|
||||
exec_approval_requirement,
|
||||
};
|
||||
let tool_ctx = ToolCtx {
|
||||
session: context.session.as_ref(),
|
||||
turn: context.turn.as_ref(),
|
||||
session: context.session.clone(),
|
||||
turn: context.turn.clone(),
|
||||
call_id: context.call_id.clone(),
|
||||
tool_name: "exec_command".to_string(),
|
||||
};
|
||||
@@ -604,7 +604,7 @@ impl UnifiedExecProcessManager {
|
||||
&mut runtime,
|
||||
&req,
|
||||
&tool_ctx,
|
||||
context.turn.as_ref(),
|
||||
&context.turn,
|
||||
context.turn.approval_policy.value(),
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -1,557 +0,0 @@
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::tools::sandboxing::ToolError;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(unix)]
|
||||
use crate::error::CodexErr;
|
||||
#[cfg(unix)]
|
||||
use crate::error::SandboxErr;
|
||||
#[cfg(unix)]
|
||||
use crate::protocol::EventMsg;
|
||||
#[cfg(unix)]
|
||||
use crate::protocol::ExecCommandOutputDeltaEvent;
|
||||
#[cfg(unix)]
|
||||
use crate::protocol::ExecOutputStream;
|
||||
#[cfg(unix)]
|
||||
use crate::protocol::ReviewDecision;
|
||||
#[cfg(unix)]
|
||||
use anyhow::Context as _;
|
||||
#[cfg(unix)]
|
||||
use codex_protocol::approvals::ExecPolicyAmendment;
|
||||
#[cfg(unix)]
|
||||
use codex_utils_pty::process_group::kill_child_process_group;
|
||||
#[cfg(unix)]
|
||||
use serde::Deserialize;
|
||||
#[cfg(unix)]
|
||||
use serde::Serialize;
|
||||
#[cfg(unix)]
|
||||
use std::io::Read;
|
||||
#[cfg(unix)]
|
||||
use std::io::Write;
|
||||
#[cfg(unix)]
|
||||
use std::time::Instant;
|
||||
#[cfg(unix)]
|
||||
use tokio::io::AsyncReadExt;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixListener;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
pub(crate) const ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR: &str =
|
||||
"CODEX_ZSH_EXEC_BRIDGE_WRAPPER_SOCKET";
|
||||
pub(crate) const ZSH_EXEC_WRAPPER_MODE_ENV_VAR: &str = "CODEX_ZSH_EXEC_WRAPPER_MODE";
|
||||
#[cfg(unix)]
|
||||
pub(crate) const EXEC_WRAPPER_ENV_VAR: &str = "EXEC_WRAPPER";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub(crate) struct ZshExecBridgeSessionState {
|
||||
pub(crate) initialized_session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct ZshExecBridge {
|
||||
zsh_path: Option<PathBuf>,
|
||||
state: Mutex<ZshExecBridgeSessionState>,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum WrapperIpcRequest {
|
||||
ExecRequest {
|
||||
request_id: String,
|
||||
file: String,
|
||||
argv: Vec<String>,
|
||||
cwd: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum WrapperIpcResponse {
|
||||
ExecResponse {
|
||||
request_id: String,
|
||||
action: WrapperExecAction,
|
||||
reason: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum WrapperExecAction {
|
||||
Run,
|
||||
Deny,
|
||||
}
|
||||
|
||||
impl ZshExecBridge {
|
||||
pub(crate) fn new(zsh_path: Option<PathBuf>, _codex_home: PathBuf) -> Self {
|
||||
Self {
|
||||
zsh_path,
|
||||
state: Mutex::new(ZshExecBridgeSessionState::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn initialize_for_session(&self, session_id: &str) {
|
||||
let mut state = self.state.lock().await;
|
||||
state.initialized_session_id = Some(session_id.to_string());
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&self) {
|
||||
let mut state = self.state.lock().await;
|
||||
state.initialized_session_id = None;
|
||||
}
|
||||
|
||||
pub(crate) fn next_wrapper_socket_path(&self) -> PathBuf {
|
||||
let socket_id = Uuid::new_v4().as_simple().to_string();
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let canonical_temp_dir = temp_dir.canonicalize().unwrap_or(temp_dir);
|
||||
canonical_temp_dir.join(format!("czs-{}.sock", &socket_id[..12]))
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
pub(crate) async fn execute_shell_request(
|
||||
&self,
|
||||
_req: &crate::sandboxing::ExecRequest,
|
||||
_session: &crate::codex::Session,
|
||||
_turn: &crate::codex::TurnContext,
|
||||
_call_id: &str,
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
let _ = &self.zsh_path;
|
||||
Err(ToolError::Rejected(
|
||||
"shell_zsh_fork is only supported on unix".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) async fn execute_shell_request(
|
||||
&self,
|
||||
req: &crate::sandboxing::ExecRequest,
|
||||
session: &crate::codex::Session,
|
||||
turn: &crate::codex::TurnContext,
|
||||
call_id: &str,
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
let zsh_path = self.zsh_path.clone().ok_or_else(|| {
|
||||
ToolError::Rejected(
|
||||
"shell_zsh_fork enabled, but zsh_path is not configured".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let command = req.command.clone();
|
||||
if command.is_empty() {
|
||||
return Err(ToolError::Rejected("command args are empty".to_string()));
|
||||
}
|
||||
|
||||
let wrapper_socket_path = req
|
||||
.env
|
||||
.get(ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| self.next_wrapper_socket_path());
|
||||
|
||||
let listener = {
|
||||
let _ = std::fs::remove_file(&wrapper_socket_path);
|
||||
UnixListener::bind(&wrapper_socket_path).map_err(|err| {
|
||||
ToolError::Rejected(format!(
|
||||
"bind wrapper socket at {}: {err}",
|
||||
wrapper_socket_path.display()
|
||||
))
|
||||
})?
|
||||
};
|
||||
|
||||
let wrapper_path = std::env::current_exe().map_err(|err| {
|
||||
ToolError::Rejected(format!("resolve current executable path: {err}"))
|
||||
})?;
|
||||
|
||||
let mut cmd = tokio::process::Command::new(&command[0]);
|
||||
#[cfg(unix)]
|
||||
if let Some(arg0) = &req.arg0 {
|
||||
cmd.arg0(arg0);
|
||||
}
|
||||
if command.len() > 1 {
|
||||
cmd.args(&command[1..]);
|
||||
}
|
||||
cmd.current_dir(&req.cwd);
|
||||
cmd.stdin(std::process::Stdio::null());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.env_clear();
|
||||
cmd.envs(&req.env);
|
||||
cmd.env(
|
||||
ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR,
|
||||
wrapper_socket_path.to_string_lossy().to_string(),
|
||||
);
|
||||
cmd.env(EXEC_WRAPPER_ENV_VAR, &wrapper_path);
|
||||
cmd.env(ZSH_EXEC_WRAPPER_MODE_ENV_VAR, "1");
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
ToolError::Rejected(format!(
|
||||
"failed to start zsh fork command {} with zsh_path {}: {err}",
|
||||
command[0],
|
||||
zsh_path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
let (stream_tx, mut stream_rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<(ExecOutputStream, Vec<u8>)>();
|
||||
|
||||
if let Some(mut out) = child.stdout.take() {
|
||||
let tx = stream_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = [0_u8; 8192];
|
||||
loop {
|
||||
let read = match out.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => n,
|
||||
Err(err) => {
|
||||
tracing::warn!("zsh fork stdout read error: {err}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
let _ = tx.send((ExecOutputStream::Stdout, buf[..read].to_vec()));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(mut err) = child.stderr.take() {
|
||||
let tx = stream_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = [0_u8; 8192];
|
||||
loop {
|
||||
let read = match err.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => n,
|
||||
Err(err) => {
|
||||
tracing::warn!("zsh fork stderr read error: {err}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
let _ = tx.send((ExecOutputStream::Stderr, buf[..read].to_vec()));
|
||||
}
|
||||
});
|
||||
}
|
||||
drop(stream_tx);
|
||||
|
||||
let mut stdout_bytes = Vec::new();
|
||||
let mut stderr_bytes = Vec::new();
|
||||
let mut child_exit = None;
|
||||
let mut timed_out = false;
|
||||
let mut stream_open = true;
|
||||
let mut user_rejected = false;
|
||||
let start = Instant::now();
|
||||
|
||||
let expiration = req.expiration.clone().wait();
|
||||
tokio::pin!(expiration);
|
||||
|
||||
while child_exit.is_none() || stream_open {
|
||||
tokio::select! {
|
||||
result = child.wait(), if child_exit.is_none() => {
|
||||
child_exit = Some(result.map_err(|err| ToolError::Rejected(format!("wait for zsh fork command exit: {err}")))?);
|
||||
}
|
||||
stream = stream_rx.recv(), if stream_open => {
|
||||
if let Some((output_stream, chunk)) = stream {
|
||||
match output_stream {
|
||||
ExecOutputStream::Stdout => stdout_bytes.extend_from_slice(&chunk),
|
||||
ExecOutputStream::Stderr => stderr_bytes.extend_from_slice(&chunk),
|
||||
}
|
||||
session
|
||||
.send_event(
|
||||
turn,
|
||||
EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
|
||||
call_id: call_id.to_string(),
|
||||
stream: output_stream,
|
||||
chunk,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
stream_open = false;
|
||||
}
|
||||
}
|
||||
accept_result = listener.accept(), if child_exit.is_none() => {
|
||||
let (stream, _) = accept_result.map_err(|err| {
|
||||
ToolError::Rejected(format!("failed to accept wrapper request: {err}"))
|
||||
})?;
|
||||
if self
|
||||
.handle_wrapper_request(stream, req.justification.clone(), session, turn, call_id)
|
||||
.await?
|
||||
{
|
||||
user_rejected = true;
|
||||
}
|
||||
}
|
||||
_ = &mut expiration, if child_exit.is_none() => {
|
||||
timed_out = true;
|
||||
kill_child_process_group(&mut child).map_err(|err| {
|
||||
ToolError::Rejected(format!("kill zsh fork command process group: {err}"))
|
||||
})?;
|
||||
child.start_kill().map_err(|err| {
|
||||
ToolError::Rejected(format!("kill zsh fork command process: {err}"))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = std::fs::remove_file(&wrapper_socket_path);
|
||||
|
||||
let status = child_exit.ok_or_else(|| {
|
||||
ToolError::Rejected("zsh fork command did not return exit status".to_string())
|
||||
})?;
|
||||
|
||||
if user_rejected {
|
||||
return Err(ToolError::Rejected("rejected by user".to_string()));
|
||||
}
|
||||
|
||||
let stdout_text = crate::text_encoding::bytes_to_string_smart(&stdout_bytes);
|
||||
let stderr_text = crate::text_encoding::bytes_to_string_smart(&stderr_bytes);
|
||||
let output = ExecToolCallOutput {
|
||||
exit_code: status.code().unwrap_or(-1),
|
||||
stdout: crate::exec::StreamOutput::new(stdout_text.clone()),
|
||||
stderr: crate::exec::StreamOutput::new(stderr_text.clone()),
|
||||
aggregated_output: crate::exec::StreamOutput::new(format!(
|
||||
"{stdout_text}{stderr_text}"
|
||||
)),
|
||||
duration: start.elapsed(),
|
||||
timed_out,
|
||||
};
|
||||
|
||||
Self::map_exec_result(req.sandbox, output)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn handle_wrapper_request(
|
||||
&self,
|
||||
mut stream: UnixStream,
|
||||
approval_reason: Option<String>,
|
||||
session: &crate::codex::Session,
|
||||
turn: &crate::codex::TurnContext,
|
||||
call_id: &str,
|
||||
) -> Result<bool, ToolError> {
|
||||
let mut request_buf = Vec::new();
|
||||
stream.read_to_end(&mut request_buf).await.map_err(|err| {
|
||||
ToolError::Rejected(format!("read wrapper request from socket: {err}"))
|
||||
})?;
|
||||
let request_line = String::from_utf8(request_buf).map_err(|err| {
|
||||
ToolError::Rejected(format!("decode wrapper request as utf-8: {err}"))
|
||||
})?;
|
||||
let request = parse_wrapper_request_line(request_line.trim())?;
|
||||
|
||||
let (request_id, file, argv, cwd) = match request {
|
||||
WrapperIpcRequest::ExecRequest {
|
||||
request_id,
|
||||
file,
|
||||
argv,
|
||||
cwd,
|
||||
} => (request_id, file, argv, cwd),
|
||||
};
|
||||
|
||||
let command_for_approval = if argv.is_empty() {
|
||||
vec![file.clone()]
|
||||
} else {
|
||||
argv.clone()
|
||||
};
|
||||
|
||||
let approval_id = Uuid::new_v4().to_string();
|
||||
let decision = session
|
||||
.request_command_approval(
|
||||
turn,
|
||||
call_id.to_string(),
|
||||
Some(approval_id),
|
||||
command_for_approval,
|
||||
PathBuf::from(cwd),
|
||||
approval_reason,
|
||||
None,
|
||||
None::<ExecPolicyAmendment>,
|
||||
)
|
||||
.await;
|
||||
|
||||
let (action, reason, user_rejected) = match decision {
|
||||
ReviewDecision::Approved
|
||||
| ReviewDecision::ApprovedForSession
|
||||
| ReviewDecision::ApprovedExecpolicyAmendment { .. } => {
|
||||
(WrapperExecAction::Run, None, false)
|
||||
}
|
||||
ReviewDecision::Denied => (
|
||||
WrapperExecAction::Deny,
|
||||
Some("command denied by host approval policy".to_string()),
|
||||
true,
|
||||
),
|
||||
ReviewDecision::Abort => (
|
||||
WrapperExecAction::Deny,
|
||||
Some("command aborted by host approval policy".to_string()),
|
||||
true,
|
||||
),
|
||||
};
|
||||
|
||||
write_json_line(
|
||||
&mut stream,
|
||||
&WrapperIpcResponse::ExecResponse {
|
||||
request_id,
|
||||
action,
|
||||
reason,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(user_rejected)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn map_exec_result(
|
||||
sandbox: crate::exec::SandboxType,
|
||||
output: ExecToolCallOutput,
|
||||
) -> Result<ExecToolCallOutput, ToolError> {
|
||||
if output.timed_out {
|
||||
return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Timeout {
|
||||
output: Box::new(output),
|
||||
})));
|
||||
}
|
||||
|
||||
if crate::exec::is_likely_sandbox_denied(sandbox, &output) {
|
||||
return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied {
|
||||
output: Box::new(output),
|
||||
network_policy_decision: None,
|
||||
})));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_run_zsh_exec_wrapper_mode() -> anyhow::Result<bool> {
|
||||
if std::env::var_os(ZSH_EXEC_WRAPPER_MODE_ENV_VAR).is_none() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
run_exec_wrapper_mode()?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn run_exec_wrapper_mode() -> anyhow::Result<()> {
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
anyhow::bail!("zsh exec wrapper mode is only supported on unix");
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
anyhow::bail!("exec wrapper mode requires target executable path");
|
||||
}
|
||||
let file = args[1].clone();
|
||||
let argv = if args.len() > 2 {
|
||||
args[2..].to_vec()
|
||||
} else {
|
||||
vec![file.clone()]
|
||||
};
|
||||
let cwd = std::env::current_dir()
|
||||
.context("resolve wrapper cwd")?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
let socket_path = std::env::var(ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR)
|
||||
.context("missing wrapper socket path env var")?;
|
||||
|
||||
let request_id = Uuid::new_v4().to_string();
|
||||
let request = WrapperIpcRequest::ExecRequest {
|
||||
request_id: request_id.clone(),
|
||||
file: file.clone(),
|
||||
argv: argv.clone(),
|
||||
cwd,
|
||||
};
|
||||
let mut stream = StdUnixStream::connect(&socket_path)
|
||||
.with_context(|| format!("connect to wrapper socket at {socket_path}"))?;
|
||||
let encoded = serde_json::to_string(&request).context("serialize wrapper request")?;
|
||||
stream
|
||||
.write_all(encoded.as_bytes())
|
||||
.context("write wrapper request")?;
|
||||
stream
|
||||
.write_all(b"\n")
|
||||
.context("write wrapper request newline")?;
|
||||
stream
|
||||
.shutdown(std::net::Shutdown::Write)
|
||||
.context("shutdown wrapper write")?;
|
||||
|
||||
let mut response_buf = String::new();
|
||||
stream
|
||||
.read_to_string(&mut response_buf)
|
||||
.context("read wrapper response")?;
|
||||
let response: WrapperIpcResponse =
|
||||
serde_json::from_str(response_buf.trim()).context("parse wrapper response")?;
|
||||
|
||||
let (response_request_id, action, reason) = match response {
|
||||
WrapperIpcResponse::ExecResponse {
|
||||
request_id,
|
||||
action,
|
||||
reason,
|
||||
} => (request_id, action, reason),
|
||||
};
|
||||
if response_request_id != request_id {
|
||||
anyhow::bail!(
|
||||
"wrapper response request_id mismatch: expected {request_id}, got {response_request_id}"
|
||||
);
|
||||
}
|
||||
|
||||
if action == WrapperExecAction::Deny {
|
||||
if let Some(reason) = reason {
|
||||
tracing::warn!("execution denied: {reason}");
|
||||
} else {
|
||||
tracing::warn!("execution denied");
|
||||
}
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let mut command = std::process::Command::new(&file);
|
||||
if argv.len() > 1 {
|
||||
command.args(&argv[1..]);
|
||||
}
|
||||
command.env_remove(ZSH_EXEC_WRAPPER_MODE_ENV_VAR);
|
||||
command.env_remove(ZSH_EXEC_BRIDGE_WRAPPER_SOCKET_ENV_VAR);
|
||||
command.env_remove(EXEC_WRAPPER_ENV_VAR);
|
||||
let status = command.status().context("spawn wrapped executable")?;
|
||||
std::process::exit(status.code().unwrap_or(1));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn parse_wrapper_request_line(request_line: &str) -> Result<WrapperIpcRequest, ToolError> {
|
||||
serde_json::from_str(request_line)
|
||||
.map_err(|err| ToolError::Rejected(format!("parse wrapper request payload: {err}")))
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn write_json_line<W: tokio::io::AsyncWrite + Unpin, T: Serialize>(
|
||||
writer: &mut W,
|
||||
message: &T,
|
||||
) -> Result<(), ToolError> {
|
||||
let encoded = serde_json::to_string(message)
|
||||
.map_err(|err| ToolError::Rejected(format!("serialize wrapper message: {err}")))?;
|
||||
tokio::io::AsyncWriteExt::write_all(writer, encoded.as_bytes())
|
||||
.await
|
||||
.map_err(|err| ToolError::Rejected(format!("write wrapper message: {err}")))?;
|
||||
tokio::io::AsyncWriteExt::write_all(writer, b"\n")
|
||||
.await
|
||||
.map_err(|err| ToolError::Rejected(format!("write wrapper newline: {err}")))?;
|
||||
tokio::io::AsyncWriteExt::flush(writer)
|
||||
.await
|
||||
.map_err(|err| ToolError::Rejected(format!("flush wrapper message: {err}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(test, unix))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_wrapper_request_line_rejects_malformed_json() {
|
||||
let err = parse_wrapper_request_line("this-is-not-json").unwrap_err();
|
||||
let ToolError::Rejected(message) = err else {
|
||||
panic!("expected ToolError::Rejected");
|
||||
};
|
||||
assert!(message.starts_with("parse wrapper request payload:"));
|
||||
}
|
||||
}
|
||||
@@ -32,7 +32,16 @@ codex-core = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-shell-command = { workspace = true }
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
codex-shell-escalation = { workspace = true }
|
||||
||||||| base
|
||||
libc = { workspace = true }
|
||||
path-absolutize = { workspace = true }
|
||||
=======
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
codex-shell-escalation = { workspace = true }
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
rmcp = { workspace = true, default-features = false, features = [
|
||||
"auth",
|
||||
"elicitation",
|
||||
@@ -50,7 +59,22 @@ schemars = { version = "1.2.1" }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
|
||||
||||||| base
|
||||
socket2 = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
=======
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
|
||||
tokio-util = { workspace = true }
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ use codex_execpolicy::Decision;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_execpolicy::RuleMatch;
|
||||
use codex_shell_command::is_dangerous_command::command_might_be_dangerous;
|
||||
use codex_shell_escalation as shell_escalation;
|
||||
use codex_shell_escalation::unix::escalate_client::run;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
@@ -160,7 +160,7 @@ pub async fn main_execve_wrapper() -> anyhow::Result<()> {
|
||||
.init();
|
||||
|
||||
let ExecveWrapperCli { file, argv } = ExecveWrapperCli::parse();
|
||||
let exit_code = shell_escalation::run(file, argv).await?;
|
||||
let exit_code = run(file, argv).await?;
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,19 @@ use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use codex_core::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
use codex_core::MCP_SANDBOX_STATE_METHOD;
|
||||
use codex_core::SandboxState;
|
||||
use codex_core::SandboxState as CoreSandboxState;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::config_types::WindowsSandboxLevel;
|
||||
use codex_protocol::models::SandboxPermissions as ProtocolSandboxPermissions;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_shell_escalation::EscalationPolicyFactory;
|
||||
use codex_shell_escalation::run_escalate_server;
|
||||
use codex_shell_escalation::unix::escalate_server::EscalationPolicyFactory;
|
||||
use codex_shell_escalation::unix::escalate_server::ExecParams as ShellExecParams;
|
||||
use codex_shell_escalation::unix::escalate_server::ExecResult as ShellExecResult;
|
||||
use codex_shell_escalation::unix::escalate_server::SandboxState as ShellEscalationSandboxState;
|
||||
use codex_shell_escalation::unix::escalate_server::ShellCommandExecutor;
|
||||
use codex_shell_escalation::unix::escalate_server::run_escalate_server;
|
||||
use codex_shell_escalation::unix::stopwatch::Stopwatch;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::ServerHandler;
|
||||
@@ -27,7 +35,9 @@ use rmcp::tool_handler;
|
||||
use rmcp::tool_router;
|
||||
use rmcp::transport::stdio;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::unix::mcp_escalation_policy::McpEscalationPolicy;
|
||||
|
||||
@@ -50,8 +60,8 @@ pub struct ExecResult {
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
impl From<codex_shell_escalation::ExecResult> for ExecResult {
|
||||
fn from(result: codex_shell_escalation::ExecResult) -> Self {
|
||||
impl From<ShellExecResult> for ExecResult {
|
||||
fn from(result: ShellExecResult) -> Self {
|
||||
Self {
|
||||
exit_code: result.exit_code,
|
||||
output: result.output,
|
||||
@@ -68,7 +78,7 @@ pub struct ExecTool {
|
||||
execve_wrapper: PathBuf,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
preserve_program_paths: bool,
|
||||
sandbox_state: Arc<RwLock<Option<SandboxState>>>,
|
||||
sandbox_state: Arc<RwLock<Option<CoreSandboxState>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize, rmcp::schemars::JsonSchema)]
|
||||
@@ -83,7 +93,7 @@ pub struct ExecParams {
|
||||
pub login: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<ExecParams> for codex_shell_escalation::ExecParams {
|
||||
impl From<ExecParams> for ShellExecParams {
|
||||
fn from(inner: ExecParams) -> Self {
|
||||
Self {
|
||||
command: inner.command,
|
||||
@@ -99,14 +109,51 @@ struct McpEscalationPolicyFactory {
|
||||
preserve_program_paths: bool,
|
||||
}
|
||||
|
||||
struct McpShellCommandExecutor;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ShellCommandExecutor for McpShellCommandExecutor {
|
||||
async fn run(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
env: HashMap<String, String>,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &ShellEscalationSandboxState,
|
||||
) -> anyhow::Result<ShellExecResult> {
|
||||
let result = process_exec_tool_call(
|
||||
codex_core::exec::ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx),
|
||||
env,
|
||||
network: None,
|
||||
sandbox_permissions: ProtocolSandboxPermissions::UseDefault,
|
||||
windows_sandbox_level: WindowsSandboxLevel::Disabled,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
},
|
||||
&sandbox_state.sandbox_policy,
|
||||
&sandbox_state.sandbox_cwd,
|
||||
&sandbox_state.codex_linux_sandbox_exe,
|
||||
sandbox_state.use_linux_sandbox_bwrap,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ShellExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EscalationPolicyFactory for McpEscalationPolicyFactory {
|
||||
type Policy = McpEscalationPolicy;
|
||||
|
||||
fn create_policy(
|
||||
&self,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
stopwatch: codex_shell_escalation::Stopwatch,
|
||||
) -> Self::Policy {
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
|
||||
McpEscalationPolicy::new(
|
||||
policy,
|
||||
self.context.clone(),
|
||||
@@ -151,15 +198,21 @@ impl ExecTool {
|
||||
.read()
|
||||
.await
|
||||
.clone()
|
||||
.unwrap_or_else(|| SandboxState {
|
||||
.unwrap_or_else(|| CoreSandboxState {
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
sandbox_cwd: PathBuf::from(¶ms.workdir),
|
||||
use_linux_sandbox_bwrap: false,
|
||||
});
|
||||
let shell_sandbox_state = ShellEscalationSandboxState {
|
||||
sandbox_policy: sandbox_state.sandbox_policy.clone(),
|
||||
codex_linux_sandbox_exe: sandbox_state.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: sandbox_state.sandbox_cwd.clone(),
|
||||
use_linux_sandbox_bwrap: sandbox_state.use_linux_sandbox_bwrap,
|
||||
};
|
||||
let result = run_escalate_server(
|
||||
params.into(),
|
||||
&sandbox_state,
|
||||
&shell_sandbox_state,
|
||||
&self.bash_path,
|
||||
&self.execve_wrapper,
|
||||
self.policy.clone(),
|
||||
@@ -168,6 +221,7 @@ impl ExecTool {
|
||||
preserve_program_paths: self.preserve_program_paths,
|
||||
},
|
||||
effective_timeout,
|
||||
&McpShellCommandExecutor,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
@@ -236,7 +290,7 @@ impl ServerHandler for ExecTool {
|
||||
));
|
||||
};
|
||||
|
||||
let Ok(sandbox_state) = serde_json::from_value::<SandboxState>(params.clone()) else {
|
||||
let Ok(sandbox_state) = serde_json::from_value::<CoreSandboxState>(params.clone()) else {
|
||||
return Err(McpError::invalid_params(
|
||||
"failed to deserialize sandbox state".to_string(),
|
||||
Some(params),
|
||||
|
||||
@@ -2,9 +2,9 @@ use std::path::Path;
|
||||
|
||||
use codex_core::sandboxing::SandboxPermissions;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_shell_escalation::EscalateAction;
|
||||
use codex_shell_escalation::EscalationPolicy;
|
||||
use codex_shell_escalation::Stopwatch;
|
||||
use codex_shell_escalation::unix::escalate_protocol::EscalateAction;
|
||||
use codex_shell_escalation::unix::escalation_policy::EscalationPolicy;
|
||||
use codex_shell_escalation::unix::stopwatch::Stopwatch;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
|
||||
@@ -7,20 +7,21 @@ license.workspace = true
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
path-absolutize = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
socket2 = { workspace = true }
|
||||
socket2 = { workspace = true, features = ["all"] }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"net",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
"time",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
#[cfg(unix)]
|
||||
mod unix {
|
||||
mod escalate_client;
|
||||
@@ -19,3 +20,11 @@ mod unix {
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
||||||| base
|
||||
=======
|
||||
#[cfg(unix)]
|
||||
pub mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
@@ -69,3 +70,77 @@ impl EscalationPolicyFactory for ShellPolicyFactory {
|
||||
}
|
||||
}
|
||||
}
|
||||
||||||| base
|
||||
=======
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::escalate_protocol::EscalateAction;
|
||||
use crate::escalation_policy::EscalationPolicy;
|
||||
use crate::stopwatch::Stopwatch;
|
||||
use crate::unix::escalate_server::EscalationPolicyFactory;
|
||||
use codex_execpolicy::Policy;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ShellActionProvider: Send + Sync {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
stopwatch: &Stopwatch,
|
||||
) -> anyhow::Result<EscalateAction>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShellPolicyFactory {
|
||||
provider: Arc<dyn ShellActionProvider>,
|
||||
}
|
||||
|
||||
impl ShellPolicyFactory {
|
||||
pub fn new<P>(provider: P) -> Self
|
||||
where
|
||||
P: ShellActionProvider + 'static,
|
||||
{
|
||||
Self {
|
||||
provider: Arc::new(provider),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_provider(provider: Arc<dyn ShellActionProvider>) -> Self {
|
||||
Self { provider }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ShellEscalationPolicy {
|
||||
provider: Arc<dyn ShellActionProvider>,
|
||||
stopwatch: Stopwatch,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EscalationPolicy for ShellEscalationPolicy {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
) -> anyhow::Result<EscalateAction> {
|
||||
self.provider
|
||||
.determine_action(file, argv, workdir, &self.stopwatch)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl EscalationPolicyFactory for ShellPolicyFactory {
|
||||
type Policy = ShellEscalationPolicy;
|
||||
|
||||
fn create_policy(&self, _policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
|
||||
ShellEscalationPolicy {
|
||||
provider: Arc::clone(&self.provider),
|
||||
stopwatch,
|
||||
}
|
||||
}
|
||||
}
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
use std::io;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd as _;
|
||||
@@ -111,3 +112,113 @@ pub async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
|
||||
}
|
||||
}
|
||||
}
|
||||
||||||| base
|
||||
=======
|
||||
use std::io;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd as _;
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use anyhow::Context as _;
|
||||
|
||||
use crate::unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::unix::escalate_protocol::EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::unix::escalate_protocol::EscalateAction;
|
||||
use crate::unix::escalate_protocol::EscalateRequest;
|
||||
use crate::unix::escalate_protocol::EscalateResponse;
|
||||
use crate::unix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::unix::escalate_protocol::SuperExecMessage;
|
||||
use crate::unix::escalate_protocol::SuperExecResult;
|
||||
use crate::unix::socket::AsyncDatagramSocket;
|
||||
use crate::unix::socket::AsyncSocket;
|
||||
|
||||
fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
|
||||
// TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd.
|
||||
let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::<i32>()?;
|
||||
if client_fd < 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}"
|
||||
));
|
||||
}
|
||||
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
|
||||
}
|
||||
|
||||
pub async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
|
||||
let handshake_client = get_escalate_client()?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
const HANDSHAKE_MESSAGE: [u8; 1] = [0];
|
||||
handshake_client
|
||||
.send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
|
||||
.await
|
||||
.context("failed to send handshake datagram")?;
|
||||
let env = std::env::vars()
|
||||
.filter(|(k, _)| {
|
||||
!matches!(
|
||||
k.as_str(),
|
||||
ESCALATE_SOCKET_ENV_VAR | EXEC_WRAPPER_ENV_VAR | LEGACY_BASH_EXEC_WRAPPER_ENV_VAR
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: file.clone().into(),
|
||||
argv: argv.clone(),
|
||||
workdir: std::env::current_dir()?,
|
||||
env,
|
||||
})
|
||||
.await
|
||||
.context("failed to send EscalateRequest")?;
|
||||
let message = client
|
||||
.receive::<EscalateResponse>()
|
||||
.await
|
||||
.context("failed to receive EscalateResponse")?;
|
||||
match message.action {
|
||||
EscalateAction::Escalate => {
|
||||
// TODO: maybe we should send ALL open FDs (except the escalate client)?
|
||||
let fds_to_send = [
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
|
||||
];
|
||||
|
||||
// TODO: also forward signals over the super-exec socket
|
||||
|
||||
client
|
||||
.send_with_fds(
|
||||
SuperExecMessage {
|
||||
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
|
||||
},
|
||||
&fds_to_send,
|
||||
)
|
||||
.await
|
||||
.context("failed to send SuperExecMessage")?;
|
||||
let SuperExecResult { exit_code } = client.receive::<SuperExecResult>().await?;
|
||||
Ok(exit_code)
|
||||
}
|
||||
EscalateAction::Run => {
|
||||
// We avoid std::process::Command here because we want to be as transparent as
|
||||
// possible. std::os::unix::process::CommandExt has .exec() but it does some funky
|
||||
// stuff with signal masks and dup2() on its standard FDs, which we don't want.
|
||||
use std::ffi::CString;
|
||||
let file = CString::new(file).context("NUL in file")?;
|
||||
|
||||
let argv_cstrs: Vec<CString> = argv
|
||||
.iter()
|
||||
.map(|s| CString::new(s.as_str()).context("NUL in argv"))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut argv: Vec<*const libc::c_char> =
|
||||
argv_cstrs.iter().map(|s| s.as_ptr()).collect();
|
||||
argv.push(std::ptr::null());
|
||||
|
||||
let err = unsafe {
|
||||
libc::execv(file.as_ptr(), argv.as_ptr());
|
||||
std::io::Error::last_os_error()
|
||||
};
|
||||
|
||||
Err(err.into())
|
||||
}
|
||||
EscalateAction::Deny { .. } => Ok(1),
|
||||
}
|
||||
}
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::path::Path;
|
||||
@@ -375,3 +376,384 @@ mod tests {
|
||||
server_task.await?
|
||||
}
|
||||
}
|
||||
||||||| base
|
||||
=======
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use path_absolutize::Absolutize as _;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::unix::escalate_protocol::EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::unix::escalate_protocol::EscalateAction;
|
||||
use crate::unix::escalate_protocol::EscalateRequest;
|
||||
use crate::unix::escalate_protocol::EscalateResponse;
|
||||
use crate::unix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::unix::escalate_protocol::SuperExecMessage;
|
||||
use crate::unix::escalate_protocol::SuperExecResult;
|
||||
use crate::unix::escalation_policy::EscalationPolicy;
|
||||
use crate::unix::socket::AsyncDatagramSocket;
|
||||
use crate::unix::socket::AsyncSocket;
|
||||
use crate::unix::stopwatch::Stopwatch;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SandboxState {
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub sandbox_cwd: PathBuf,
|
||||
pub use_linux_sandbox_bwrap: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait ShellCommandExecutor: Send + Sync {
|
||||
async fn run(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
env: HashMap<String, String>,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> anyhow::Result<ExecResult>;
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, serde::Serialize)]
|
||||
pub struct ExecParams {
|
||||
/// The bash string to execute.
|
||||
pub command: String,
|
||||
/// The working directory to execute the command in. Must be an absolute path.
|
||||
pub workdir: String,
|
||||
/// The timeout for the command in milliseconds.
|
||||
pub timeout_ms: Option<u64>,
|
||||
/// Launch Bash with -lc instead of -c: defaults to true.
|
||||
pub login: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ExecResult {
|
||||
pub exit_code: i32,
|
||||
pub output: String,
|
||||
pub duration: Duration,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub struct EscalateServer {
|
||||
bash_path: PathBuf,
|
||||
execve_wrapper: PathBuf,
|
||||
policy: Arc<dyn EscalationPolicy>,
|
||||
}
|
||||
|
||||
impl EscalateServer {
|
||||
pub fn new<P>(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self
|
||||
where
|
||||
P: EscalationPolicy + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
bash_path,
|
||||
execve_wrapper,
|
||||
policy: Arc::new(policy),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn exec(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &SandboxState,
|
||||
command_executor: &dyn ShellCommandExecutor,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
|
||||
let client_socket = escalate_client.into_inner();
|
||||
let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone()));
|
||||
let mut env = std::env::vars().collect::<HashMap<String, String>>();
|
||||
env.insert(
|
||||
ESCALATE_SOCKET_ENV_VAR.to_string(),
|
||||
client_socket.as_raw_fd().to_string(),
|
||||
);
|
||||
env.insert(
|
||||
EXEC_WRAPPER_ENV_VAR.to_string(),
|
||||
self.execve_wrapper.to_string_lossy().to_string(),
|
||||
);
|
||||
env.insert(
|
||||
LEGACY_BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
|
||||
self.execve_wrapper.to_string_lossy().to_string(),
|
||||
);
|
||||
|
||||
let command = vec![
|
||||
self.bash_path.to_string_lossy().to_string(),
|
||||
if params.login == Some(false) {
|
||||
"-c".to_string()
|
||||
} else {
|
||||
"-lc".to_string()
|
||||
},
|
||||
params.command,
|
||||
];
|
||||
let result = command_executor
|
||||
.run(
|
||||
command,
|
||||
PathBuf::from(¶ms.workdir),
|
||||
env,
|
||||
cancel_rx,
|
||||
sandbox_state,
|
||||
)
|
||||
.await?;
|
||||
escalate_task.abort();
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory for creating escalation policy instances for a single shell run.
|
||||
pub trait EscalationPolicyFactory {
|
||||
type Policy: EscalationPolicy + Send + Sync + 'static;
|
||||
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run_escalate_server(
|
||||
exec_params: ExecParams,
|
||||
sandbox_state: &SandboxState,
|
||||
shell_program: impl AsRef<Path>,
|
||||
execve_wrapper: impl AsRef<Path>,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
escalation_policy_factory: impl EscalationPolicyFactory,
|
||||
effective_timeout: Duration,
|
||||
command_executor: &dyn ShellCommandExecutor,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
let escalate_server = EscalateServer::new(
|
||||
shell_program.as_ref().to_path_buf(),
|
||||
execve_wrapper.as_ref().to_path_buf(),
|
||||
escalation_policy_factory.create_policy(policy, stopwatch),
|
||||
);
|
||||
|
||||
escalate_server
|
||||
.exec(exec_params, cancel_token, sandbox_state, command_executor)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn escalate_task(
|
||||
socket: AsyncDatagramSocket,
|
||||
policy: Arc<dyn EscalationPolicy>,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (_, mut fds) = socket.receive_with_fds().await?;
|
||||
if fds.len() != 1 {
|
||||
tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len());
|
||||
continue;
|
||||
}
|
||||
let stream_socket = AsyncSocket::from_fd(fds.remove(0))?;
|
||||
let policy = policy.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await {
|
||||
tracing::error!("escalate session failed: {err:?}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_escalate_session_with_policy(
|
||||
socket: AsyncSocket,
|
||||
policy: Arc<dyn EscalationPolicy>,
|
||||
) -> anyhow::Result<()> {
|
||||
let EscalateRequest {
|
||||
file,
|
||||
argv,
|
||||
workdir,
|
||||
env,
|
||||
} = socket.receive::<EscalateRequest>().await?;
|
||||
let file = PathBuf::from(&file).absolutize()?.into_owned();
|
||||
let workdir = PathBuf::from(&workdir).absolutize()?.into_owned();
|
||||
let action = policy
|
||||
.determine_action(file.as_path(), &argv, &workdir)
|
||||
.await
|
||||
.context("failed to determine escalation action")?;
|
||||
|
||||
tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}");
|
||||
|
||||
match action {
|
||||
EscalateAction::Run => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Run,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
EscalateAction::Escalate => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
})
|
||||
.await?;
|
||||
let (msg, fds) = socket
|
||||
.receive_with_fds::<SuperExecMessage>()
|
||||
.await
|
||||
.context("failed to receive SuperExecMessage")?;
|
||||
if fds.len() != msg.fds.len() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message",
|
||||
msg.fds.len(),
|
||||
fds.len()
|
||||
));
|
||||
}
|
||||
|
||||
if msg
|
||||
.fds
|
||||
.iter()
|
||||
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"overlapping fds not yet supported in SuperExecMessage"
|
||||
));
|
||||
}
|
||||
|
||||
let mut command = Command::new(file);
|
||||
command
|
||||
.args(&argv[1..])
|
||||
.arg0(argv[0].clone())
|
||||
.envs(&env)
|
||||
.current_dir(&workdir)
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null());
|
||||
unsafe {
|
||||
command.pre_exec(move || {
|
||||
for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) {
|
||||
libc::dup2(src_fd.as_raw_fd(), *dst_fd);
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
let mut child = command.spawn()?;
|
||||
let exit_status = child.wait().await?;
|
||||
socket
|
||||
.send(SuperExecResult {
|
||||
exit_code: exit_status.code().unwrap_or(127),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
EscalateAction::Deny { reason } => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Deny { reason },
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct DeterministicEscalationPolicy {
|
||||
action: EscalateAction,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl EscalationPolicy for DeterministicEscalationPolicy {
|
||||
async fn determine_action(
|
||||
&self,
|
||||
_file: &Path,
|
||||
_argv: &[String],
|
||||
_workdir: &Path,
|
||||
) -> anyhow::Result<EscalateAction> {
|
||||
Ok(self.action.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
Arc::new(DeterministicEscalationPolicy {
|
||||
action: EscalateAction::Run,
|
||||
}),
|
||||
));
|
||||
|
||||
let mut env = HashMap::new();
|
||||
for i in 0..10 {
|
||||
let value = "A".repeat(1024);
|
||||
env.insert(format!("CODEX_TEST_VAR{i}"), value);
|
||||
}
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/echo"),
|
||||
argv: vec!["echo".to_string()],
|
||||
workdir: PathBuf::from("/tmp"),
|
||||
env,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Run,
|
||||
},
|
||||
response
|
||||
);
|
||||
server_task.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
Arc::new(DeterministicEscalationPolicy {
|
||||
action: EscalateAction::Escalate,
|
||||
}),
|
||||
));
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/sh"),
|
||||
argv: vec![
|
||||
"sh".to_string(),
|
||||
"-c".to_string(),
|
||||
r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(),
|
||||
],
|
||||
workdir: std::env::current_dir()?,
|
||||
env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
},
|
||||
response
|
||||
);
|
||||
|
||||
client
|
||||
.send_with_fds(SuperExecMessage { fds: Vec::new() }, &[])
|
||||
.await?;
|
||||
|
||||
let result = client.receive::<SuperExecResult>().await?;
|
||||
assert_eq!(42, result.exit_code);
|
||||
|
||||
server_task.await?
|
||||
}
|
||||
}
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
pub mod escalate_client;
|
||||
pub mod escalate_protocol;
|
||||
pub mod escalate_server;
|
||||
@@ -5,3 +6,13 @@ pub mod escalation_policy;
|
||||
pub mod socket;
|
||||
pub mod core_shell_escalation;
|
||||
pub mod stopwatch;
|
||||
||||||| base
|
||||
=======
|
||||
pub mod core_shell_escalation;
|
||||
pub mod escalate_client;
|
||||
pub mod escalate_protocol;
|
||||
pub mod escalate_server;
|
||||
pub mod escalation_policy;
|
||||
pub mod socket;
|
||||
pub mod stopwatch;
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< local: 4dbedf0a57a9 - mbolin: Use Arc-based ToolCtx in tool runtimes
|
||||
use libc::c_uint;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -505,3 +506,514 @@ mod tests {
|
||||
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
|
||||
}
|
||||
}
|
||||
||||||| base
|
||||
=======
|
||||
use libc::c_uint;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use socket2::Domain;
|
||||
use socket2::MaybeUninitSlice;
|
||||
use socket2::MsgHdr;
|
||||
use socket2::MsgHdrMut;
|
||||
use socket2::Socket;
|
||||
use socket2::Type;
|
||||
use std::io::IoSlice;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
use std::os::fd::OwnedFd;
|
||||
use std::os::fd::RawFd;
|
||||
use tokio::io::Interest;
|
||||
use tokio::io::unix::AsyncFd;
|
||||
|
||||
const MAX_FDS_PER_MESSAGE: usize = 16;
|
||||
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
|
||||
const MAX_DATAGRAM_SIZE: usize = 8192;
|
||||
|
||||
/// Converts a slice of MaybeUninit<T> to a slice of T.
|
||||
///
|
||||
/// The caller guarantees that every element of `buf` is initialized.
|
||||
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
|
||||
unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) }
|
||||
}
|
||||
|
||||
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
|
||||
unsafe { &*(buf as *const [MaybeUninit<T>; N] as *const [T; N]) }
|
||||
}
|
||||
|
||||
fn assume_init_vec<T>(mut buf: Vec<MaybeUninit<T>>) -> Vec<T> {
|
||||
unsafe {
|
||||
let ptr = buf.as_mut_ptr() as *mut T;
|
||||
let len = buf.len();
|
||||
let cap = buf.capacity();
|
||||
std::mem::forget(buf);
|
||||
Vec::from_raw_parts(ptr, len, cap)
|
||||
}
|
||||
}
|
||||
|
||||
fn control_space_for_fds(count: usize) -> usize {
|
||||
unsafe { libc::CMSG_SPACE((count * size_of::<RawFd>()) as _) as usize }
|
||||
}
|
||||
|
||||
/// Extracts the FDs from a SCM_RIGHTS control message.
|
||||
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
|
||||
let mut fds = Vec::new();
|
||||
let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
|
||||
hdr.msg_control = control.as_ptr() as *mut libc::c_void;
|
||||
hdr.msg_controllen = control.len() as _;
|
||||
let hdr = hdr; // drop mut
|
||||
|
||||
let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
|
||||
while !cmsg.is_null() {
|
||||
let level = unsafe { (*cmsg).cmsg_level };
|
||||
let ty = unsafe { (*cmsg).cmsg_type };
|
||||
if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
|
||||
let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
|
||||
let fd_count: usize = {
|
||||
let cmsg_data_len =
|
||||
unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize };
|
||||
cmsg_data_len / size_of::<RawFd>()
|
||||
};
|
||||
for i in 0..fd_count {
|
||||
let fd = unsafe { data_ptr.add(i).read() };
|
||||
fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
}
|
||||
}
|
||||
cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
|
||||
}
|
||||
fds
|
||||
}
|
||||
|
||||
/// Read a frame from a SOCK_STREAM socket.
|
||||
///
|
||||
/// A frame is a message length prefix followed by a payload. FDs may be included in the control
|
||||
/// message when receiving the frame header.
|
||||
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
let (message_len, fds) = read_frame_header(async_socket).await?;
|
||||
let payload = read_frame_payload(async_socket, message_len).await?;
|
||||
Ok((payload, fds))
|
||||
}
|
||||
|
||||
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
|
||||
async fn read_frame_header(
|
||||
async_socket: &AsyncFd<Socket>,
|
||||
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
|
||||
let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
|
||||
let mut filled = 0;
|
||||
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
|
||||
let mut captured_control = false;
|
||||
|
||||
while filled < LENGTH_PREFIX_SIZE {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
// The first read should come with a control message containing any FDs.
|
||||
let read = if !captured_control {
|
||||
match guard.try_io(|inner| {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
|
||||
let (read, control_len) = {
|
||||
let mut msg = MsgHdrMut::new()
|
||||
.with_buffers(&mut bufs)
|
||||
.with_control(&mut control);
|
||||
let read = inner.get_ref().recvmsg(&mut msg, 0)?;
|
||||
(read, msg.control_len())
|
||||
};
|
||||
control.truncate(control_len);
|
||||
captured_control = true;
|
||||
Ok(read)
|
||||
}) {
|
||||
Ok(Ok(read)) => read,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
} else {
|
||||
match guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) {
|
||||
Ok(Ok(read)) => read,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
};
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"socket closed while receiving frame header",
|
||||
));
|
||||
}
|
||||
|
||||
filled += read;
|
||||
assert!(filled <= LENGTH_PREFIX_SIZE);
|
||||
if filled == LENGTH_PREFIX_SIZE {
|
||||
let len_bytes = assume_init_slice(&header);
|
||||
let payload_len = u32::from_le_bytes(*len_bytes) as usize;
|
||||
let fds = extract_fds(assume_init(&control));
|
||||
return Ok((payload_len, fds));
|
||||
}
|
||||
}
|
||||
unreachable!("header loop always returns")
|
||||
}
|
||||
|
||||
/// Read `message_len` bytes from a SOCK_STREAM socket.
|
||||
async fn read_frame_payload(
|
||||
async_socket: &AsyncFd<Socket>,
|
||||
message_len: usize,
|
||||
) -> std::io::Result<Vec<u8>> {
|
||||
if message_len == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
|
||||
let mut filled = 0;
|
||||
while filled < message_len {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
let read = match guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])) {
|
||||
Ok(Ok(read)) => read,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
};
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"socket closed while receiving frame payload",
|
||||
));
|
||||
}
|
||||
filled += read;
|
||||
assert!(filled <= message_len);
|
||||
if filled == message_len {
|
||||
return Ok(assume_init_vec(payload));
|
||||
}
|
||||
}
|
||||
unreachable!("loop exits only after returning payload")
|
||||
}
|
||||
|
||||
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
let control = make_control_message(fds)?;
|
||||
let payload = [IoSlice::new(data)];
|
||||
let msg = if control.is_empty() {
|
||||
MsgHdr::new().with_buffers(&payload)
|
||||
} else {
|
||||
MsgHdr::new().with_buffers(&payload).with_control(&control)
|
||||
};
|
||||
let written = socket.sendmsg(&msg, 0)?;
|
||||
if written != data.len() {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
format!(
|
||||
"short datagram write: wrote {written} bytes out of {}",
|
||||
data.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
|
||||
let len_u32 = u32::try_from(len).map_err(|_| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("message too large: {len}"),
|
||||
)
|
||||
})?;
|
||||
Ok(len_u32.to_le_bytes())
|
||||
}
|
||||
|
||||
fn make_control_message(fds: &[OwnedFd]) -> std::io::Result<Vec<u8>> {
|
||||
if fds.len() > MAX_FDS_PER_MESSAGE {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("too many fds: {}", fds.len()),
|
||||
))
|
||||
} else if fds.is_empty() {
|
||||
Ok(Vec::new())
|
||||
} else {
|
||||
let mut control = vec![0u8; control_space_for_fds(fds.len())];
|
||||
unsafe {
|
||||
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
|
||||
(*cmsg).cmsg_len =
|
||||
libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
|
||||
(*cmsg).cmsg_level = libc::SOL_SOCKET;
|
||||
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
|
||||
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
|
||||
for (i, fd) in fds.iter().enumerate() {
|
||||
data_ptr.add(i).write(fd.as_raw_fd());
|
||||
}
|
||||
}
|
||||
Ok(control)
|
||||
}
|
||||
}
|
||||
|
||||
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
let mut buffer = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
|
||||
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
|
||||
let (read, control_len) = {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut buffer)];
|
||||
let mut msg = MsgHdrMut::new()
|
||||
.with_buffers(&mut bufs)
|
||||
.with_control(&mut control);
|
||||
let read = socket.recvmsg(&mut msg, 0)?;
|
||||
(read, msg.control_len())
|
||||
};
|
||||
let data = assume_init(&buffer[..read]).to_vec();
|
||||
let fds = extract_fds(assume_init(&control[..control_len]));
|
||||
Ok((data, fds))
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncSocket {
|
||||
inner: AsyncFd<Socket>,
|
||||
}
|
||||
|
||||
impl AsyncSocket {
|
||||
fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
|
||||
socket.set_nonblocking(true)?;
|
||||
let async_socket = AsyncFd::new(socket)?;
|
||||
Ok(AsyncSocket {
|
||||
inner: async_socket,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
|
||||
AsyncSocket::new(Socket::from(fd))
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
|
||||
let (server, client) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
|
||||
Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
|
||||
}
|
||||
|
||||
pub async fn send_with_fds<T: Serialize>(
|
||||
&self,
|
||||
msg: T,
|
||||
fds: &[OwnedFd],
|
||||
) -> std::io::Result<()> {
|
||||
let payload = serde_json::to_vec(&msg)?;
|
||||
let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + payload.len());
|
||||
frame.extend_from_slice(&encode_length(payload.len())?);
|
||||
frame.extend_from_slice(&payload);
|
||||
send_stream_frame(&self.inner, &frame, fds).await
|
||||
}
|
||||
|
||||
pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
|
||||
&self,
|
||||
) -> std::io::Result<(T, Vec<OwnedFd>)> {
|
||||
let (payload, fds) = read_frame(&self.inner).await?;
|
||||
let message: T = serde_json::from_slice(&payload)?;
|
||||
Ok((message, fds))
|
||||
}
|
||||
|
||||
pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
self.send_with_fds(&msg, &[]).await
|
||||
}
|
||||
|
||||
pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
|
||||
let (msg, fds) = self.receive_with_fds().await?;
|
||||
if !fds.is_empty() {
|
||||
tracing::warn!("unexpected fds in receive: {}", fds.len());
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Socket {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_stream_frame(
|
||||
socket: &AsyncFd<Socket>,
|
||||
frame: &[u8],
|
||||
fds: &[OwnedFd],
|
||||
) -> std::io::Result<()> {
|
||||
let mut written = 0;
|
||||
let mut include_fds = !fds.is_empty();
|
||||
while written < frame.len() {
|
||||
let mut guard = socket.writable().await?;
|
||||
let bytes_written = match guard
|
||||
.try_io(|inner| send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds))
|
||||
{
|
||||
Ok(Ok(bytes_written)) => bytes_written,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
};
|
||||
if bytes_written == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"socket closed while sending frame payload",
|
||||
));
|
||||
}
|
||||
written += bytes_written;
|
||||
include_fds = false;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_stream_chunk(
|
||||
socket: &Socket,
|
||||
frame: &[u8],
|
||||
fds: &[OwnedFd],
|
||||
include_fds: bool,
|
||||
) -> std::io::Result<usize> {
|
||||
let control = if include_fds {
|
||||
make_control_message(fds)?
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let payload = [IoSlice::new(frame)];
|
||||
let msg = if control.is_empty() {
|
||||
MsgHdr::new().with_buffers(&payload)
|
||||
} else {
|
||||
MsgHdr::new().with_buffers(&payload).with_control(&control)
|
||||
};
|
||||
socket.sendmsg(&msg, 0)
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncDatagramSocket {
|
||||
inner: AsyncFd<Socket>,
|
||||
}
|
||||
|
||||
impl AsyncDatagramSocket {
|
||||
fn new(socket: Socket) -> std::io::Result<Self> {
|
||||
socket.set_nonblocking(true)?;
|
||||
Ok(Self {
|
||||
inner: AsyncFd::new(socket)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
|
||||
Self::new(unsafe { Socket::from_raw_fd(fd) })
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(Self, Self)> {
|
||||
let (server, client) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
|
||||
Ok((Self::new(server)?, Self::new(client)?))
|
||||
}
|
||||
|
||||
pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
self.inner
|
||||
.async_io(Interest::WRITABLE, |socket| {
|
||||
send_datagram_bytes(socket, data, fds)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
self.inner
|
||||
.async_io(Interest::READABLE, receive_datagram_bytes)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Socket {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::os::fd::AsFd;
|
||||
use std::os::fd::AsRawFd;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
struct TestPayload {
|
||||
id: i32,
|
||||
label: String,
|
||||
}
|
||||
|
||||
fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
|
||||
let file = NamedTempFile::new()?;
|
||||
let mut fds = Vec::new();
|
||||
for _ in 0..count {
|
||||
fds.push(file.as_fd().try_clone_to_owned()?);
|
||||
}
|
||||
Ok(fds)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let payload = TestPayload {
|
||||
id: 7,
|
||||
label: "round-trip".to_string(),
|
||||
};
|
||||
let send_fds = fd_list(1)?;
|
||||
|
||||
let receive_task =
|
||||
tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
|
||||
client.send_with_fds(payload.clone(), &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_payload, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(payload, received_payload);
|
||||
assert_eq!(1, received_fds.len());
|
||||
let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
|
||||
assert!(
|
||||
fd_status >= 0,
|
||||
"expected received file descriptor to be valid, but got {fd_status}",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_socket_handles_large_payload() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let payload = vec![b'A'; 10_000];
|
||||
let receive_task = tokio::spawn(async move { server.receive::<Vec<u8>>().await });
|
||||
client.send(payload.clone()).await?;
|
||||
let received_payload = receive_task.await.unwrap()?;
|
||||
assert_eq!(payload, received_payload);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncDatagramSocket::pair()?;
|
||||
let data = b"datagram payload".to_vec();
|
||||
let send_fds = fd_list(1)?;
|
||||
let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
|
||||
|
||||
client.send_with_fds(&data, &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_bytes, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(data, received_bytes);
|
||||
assert_eq!(1, received_fds.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_length_errors_for_oversized_messages() {
|
||||
let err = encode_length(usize::MAX).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_fails_when_peer_closes_before_header() {
|
||||
let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
|
||||
drop(client);
|
||||
let err = server
|
||||
.receive::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("expected read failure");
|
||||
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
|
||||
}
|
||||
}
|
||||
>>>>>>> graft: 3aab533f2d22 - mbolin: Support zsh shell tool with shell-escal...
|
||||
|
||||
Reference in New Issue
Block a user