mirror of
https://github.com/openai/codex.git
synced 2026-02-25 18:23:47 +00:00
Compare commits
2 Commits
dev/cc/new
...
pr12615
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beaffad47e | ||
|
|
b9fe46fb04 |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -1859,6 +1859,7 @@ dependencies = [
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
@@ -32,6 +32,8 @@ codex-core = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-shell-command = { workspace = true }
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
codex-shell-escalation = { workspace = true }
|
||||
rmcp = { workspace = true, default-features = false, features = [
|
||||
"auth",
|
||||
@@ -51,6 +53,7 @@ serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ use codex_execpolicy::Decision;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_execpolicy::RuleMatch;
|
||||
use codex_shell_command::is_dangerous_command::command_might_be_dangerous;
|
||||
use codex_shell_escalation as shell_escalation;
|
||||
use codex_shell_escalation::unix::escalate_client::run;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
@@ -160,7 +160,7 @@ pub async fn main_execve_wrapper() -> anyhow::Result<()> {
|
||||
.init();
|
||||
|
||||
let ExecveWrapperCli { file, argv } = ExecveWrapperCli::parse();
|
||||
let exit_code = shell_escalation::run(file, argv).await?;
|
||||
let exit_code = run(file, argv).await?;
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,19 @@ use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use codex_core::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
use codex_core::MCP_SANDBOX_STATE_METHOD;
|
||||
use codex_core::SandboxState;
|
||||
use codex_core::SandboxState as CoreSandboxState;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::config_types::WindowsSandboxLevel;
|
||||
use codex_protocol::models::SandboxPermissions as ProtocolSandboxPermissions;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_shell_escalation::EscalationPolicyFactory;
|
||||
use codex_shell_escalation::run_escalate_server;
|
||||
use codex_shell_escalation::unix::escalate_server::EscalationPolicyFactory;
|
||||
use codex_shell_escalation::unix::escalate_server::ExecParams as ShellExecParams;
|
||||
use codex_shell_escalation::unix::escalate_server::ExecResult as ShellExecResult;
|
||||
use codex_shell_escalation::unix::escalate_server::SandboxState as ShellEscalationSandboxState;
|
||||
use codex_shell_escalation::unix::escalate_server::ShellCommandExecutor;
|
||||
use codex_shell_escalation::unix::escalate_server::run_escalate_server;
|
||||
use codex_shell_escalation::unix::stopwatch::Stopwatch;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::ServerHandler;
|
||||
@@ -27,7 +35,9 @@ use rmcp::tool_handler;
|
||||
use rmcp::tool_router;
|
||||
use rmcp::transport::stdio;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::unix::mcp_escalation_policy::McpEscalationPolicy;
|
||||
|
||||
@@ -50,8 +60,8 @@ pub struct ExecResult {
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
impl From<codex_shell_escalation::ExecResult> for ExecResult {
|
||||
fn from(result: codex_shell_escalation::ExecResult) -> Self {
|
||||
impl From<ShellExecResult> for ExecResult {
|
||||
fn from(result: ShellExecResult) -> Self {
|
||||
Self {
|
||||
exit_code: result.exit_code,
|
||||
output: result.output,
|
||||
@@ -68,7 +78,7 @@ pub struct ExecTool {
|
||||
execve_wrapper: PathBuf,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
preserve_program_paths: bool,
|
||||
sandbox_state: Arc<RwLock<Option<SandboxState>>>,
|
||||
sandbox_state: Arc<RwLock<Option<CoreSandboxState>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize, rmcp::schemars::JsonSchema)]
|
||||
@@ -83,7 +93,7 @@ pub struct ExecParams {
|
||||
pub login: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<ExecParams> for codex_shell_escalation::ExecParams {
|
||||
impl From<ExecParams> for ShellExecParams {
|
||||
fn from(inner: ExecParams) -> Self {
|
||||
Self {
|
||||
command: inner.command,
|
||||
@@ -99,14 +109,51 @@ struct McpEscalationPolicyFactory {
|
||||
preserve_program_paths: bool,
|
||||
}
|
||||
|
||||
struct McpShellCommandExecutor;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ShellCommandExecutor for McpShellCommandExecutor {
|
||||
async fn run(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
env: HashMap<String, String>,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &ShellEscalationSandboxState,
|
||||
) -> anyhow::Result<ShellExecResult> {
|
||||
let result = process_exec_tool_call(
|
||||
codex_core::exec::ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx),
|
||||
env,
|
||||
network: None,
|
||||
sandbox_permissions: ProtocolSandboxPermissions::UseDefault,
|
||||
windows_sandbox_level: WindowsSandboxLevel::Disabled,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
},
|
||||
&sandbox_state.sandbox_policy,
|
||||
&sandbox_state.sandbox_cwd,
|
||||
&sandbox_state.codex_linux_sandbox_exe,
|
||||
sandbox_state.use_linux_sandbox_bwrap,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ShellExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EscalationPolicyFactory for McpEscalationPolicyFactory {
|
||||
type Policy = McpEscalationPolicy;
|
||||
|
||||
fn create_policy(
|
||||
&self,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
stopwatch: codex_shell_escalation::Stopwatch,
|
||||
) -> Self::Policy {
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
|
||||
McpEscalationPolicy::new(
|
||||
policy,
|
||||
self.context.clone(),
|
||||
@@ -151,15 +198,21 @@ impl ExecTool {
|
||||
.read()
|
||||
.await
|
||||
.clone()
|
||||
.unwrap_or_else(|| SandboxState {
|
||||
.unwrap_or_else(|| CoreSandboxState {
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
codex_linux_sandbox_exe: None,
|
||||
sandbox_cwd: PathBuf::from(¶ms.workdir),
|
||||
use_linux_sandbox_bwrap: false,
|
||||
});
|
||||
let shell_sandbox_state = ShellEscalationSandboxState {
|
||||
sandbox_policy: sandbox_state.sandbox_policy.clone(),
|
||||
codex_linux_sandbox_exe: sandbox_state.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: sandbox_state.sandbox_cwd.clone(),
|
||||
use_linux_sandbox_bwrap: sandbox_state.use_linux_sandbox_bwrap,
|
||||
};
|
||||
let result = run_escalate_server(
|
||||
params.into(),
|
||||
&sandbox_state,
|
||||
&shell_sandbox_state,
|
||||
&self.bash_path,
|
||||
&self.execve_wrapper,
|
||||
self.policy.clone(),
|
||||
@@ -168,6 +221,7 @@ impl ExecTool {
|
||||
preserve_program_paths: self.preserve_program_paths,
|
||||
},
|
||||
effective_timeout,
|
||||
&McpShellCommandExecutor,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
@@ -236,7 +290,7 @@ impl ServerHandler for ExecTool {
|
||||
));
|
||||
};
|
||||
|
||||
let Ok(sandbox_state) = serde_json::from_value::<SandboxState>(params.clone()) else {
|
||||
let Ok(sandbox_state) = serde_json::from_value::<CoreSandboxState>(params.clone()) else {
|
||||
return Err(McpError::invalid_params(
|
||||
"failed to deserialize sandbox state".to_string(),
|
||||
Some(params),
|
||||
|
||||
@@ -2,9 +2,9 @@ use std::path::Path;
|
||||
|
||||
use codex_core::sandboxing::SandboxPermissions;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_shell_escalation::EscalateAction;
|
||||
use codex_shell_escalation::EscalationPolicy;
|
||||
use codex_shell_escalation::Stopwatch;
|
||||
use codex_shell_escalation::unix::escalate_protocol::EscalateAction;
|
||||
use codex_shell_escalation::unix::escalation_policy::EscalationPolicy;
|
||||
use codex_shell_escalation::unix::stopwatch::Stopwatch;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
|
||||
@@ -14,13 +14,15 @@ libc = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
path-absolutize = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
socket2 = { workspace = true }
|
||||
socket2 = { workspace = true, features = ["all"] }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"net",
|
||||
"macros",
|
||||
"process",
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
"time",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -1,21 +1,114 @@
|
||||
#[cfg(unix)]
|
||||
mod unix {
|
||||
mod escalate_client;
|
||||
mod escalate_protocol;
|
||||
mod escalate_server;
|
||||
mod escalation_policy;
|
||||
mod socket;
|
||||
mod stopwatch;
|
||||
|
||||
pub use self::escalate_client::run;
|
||||
pub use self::escalate_protocol::EscalateAction;
|
||||
pub use self::escalate_server::EscalationPolicyFactory;
|
||||
pub use self::escalate_server::ExecParams;
|
||||
pub use self::escalate_server::ExecResult;
|
||||
pub use self::escalate_server::run_escalate_server;
|
||||
pub use self::escalation_policy::EscalationPolicy;
|
||||
pub use self::stopwatch::Stopwatch;
|
||||
}
|
||||
pub mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalate_client::run;
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalate_protocol::EscalateAction;
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalate_server::EscalationPolicyFactory;
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalate_server::ExecParams;
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalate_server::ExecResult;
|
||||
#[cfg(unix)]
|
||||
pub use unix::escalation_policy::EscalationPolicy;
|
||||
#[cfg(unix)]
|
||||
pub use unix::stopwatch::Stopwatch;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod legacy_api {
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::config_types::WindowsSandboxLevel;
|
||||
use codex_protocol::models::SandboxPermissions as ProtocolSandboxPermissions;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::unix::escalate_server::EscalationPolicyFactory;
|
||||
use crate::unix::escalate_server::ExecParams;
|
||||
use crate::unix::escalate_server::ExecResult;
|
||||
use crate::unix::escalate_server::SandboxState;
|
||||
use crate::unix::escalate_server::ShellCommandExecutor;
|
||||
|
||||
struct CoreShellCommandExecutor;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ShellCommandExecutor for CoreShellCommandExecutor {
|
||||
async fn run(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
env: HashMap<String, String>,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let result = codex_core::exec::process_exec_tool_call(
|
||||
codex_core::exec::ExecParams {
|
||||
command,
|
||||
cwd,
|
||||
expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx),
|
||||
env,
|
||||
network: None,
|
||||
sandbox_permissions: ProtocolSandboxPermissions::UseDefault,
|
||||
windows_sandbox_level: WindowsSandboxLevel::Disabled,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
},
|
||||
&sandbox_state.sandbox_policy,
|
||||
&sandbox_state.sandbox_cwd,
|
||||
&sandbox_state.codex_linux_sandbox_exe,
|
||||
sandbox_state.use_linux_sandbox_bwrap,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run_escalate_server(
|
||||
exec_params: ExecParams,
|
||||
sandbox_state: &codex_core::SandboxState,
|
||||
shell_program: impl AsRef<Path>,
|
||||
execve_wrapper: impl AsRef<Path>,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
escalation_policy_factory: impl EscalationPolicyFactory,
|
||||
effective_timeout: Duration,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let sandbox_state = SandboxState {
|
||||
sandbox_policy: sandbox_state.sandbox_policy.clone(),
|
||||
codex_linux_sandbox_exe: sandbox_state.codex_linux_sandbox_exe.clone(),
|
||||
sandbox_cwd: sandbox_state.sandbox_cwd.clone(),
|
||||
use_linux_sandbox_bwrap: sandbox_state.use_linux_sandbox_bwrap,
|
||||
};
|
||||
crate::unix::escalate_server::run_escalate_server(
|
||||
exec_params,
|
||||
&sandbox_state,
|
||||
shell_program,
|
||||
execve_wrapper,
|
||||
policy,
|
||||
escalation_policy_factory,
|
||||
effective_timeout,
|
||||
&CoreShellCommandExecutor,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use legacy_api::run_escalate_server;
|
||||
|
||||
@@ -40,7 +40,7 @@ impl ShellPolicyFactory {
|
||||
}
|
||||
}
|
||||
|
||||
struct ShellEscalationPolicy {
|
||||
pub struct ShellEscalationPolicy {
|
||||
provider: Arc<dyn ShellActionProvider>,
|
||||
stopwatch: Stopwatch,
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use codex_core::SandboxState;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use path_absolutize::Absolutize as _;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::RwLock;
|
||||
@@ -27,6 +27,33 @@ use crate::unix::socket::AsyncDatagramSocket;
|
||||
use crate::unix::socket::AsyncSocket;
|
||||
use crate::unix::stopwatch::Stopwatch;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Sandbox configuration forwarded to the embedding crate's process executor.
|
||||
pub struct SandboxState {
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
pub codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub sandbox_cwd: PathBuf,
|
||||
pub use_linux_sandbox_bwrap: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
/// Adapter for running the shell command after the escalation server has been set up.
|
||||
///
|
||||
/// This lets `shell-escalation` own the Unix escalation protocol while the caller
|
||||
/// (for example `codex-core` or `exec-server`) keeps control over process spawning,
|
||||
/// output capture, and sandbox integration.
|
||||
pub trait ShellCommandExecutor: Send + Sync {
|
||||
/// Runs the requested shell command and returns the captured result.
|
||||
async fn run(
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
env: HashMap<String, String>,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &SandboxState,
|
||||
) -> anyhow::Result<ExecResult>;
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, serde::Serialize)]
|
||||
pub struct ExecParams {
|
||||
/// The bash string to execute.
|
||||
@@ -71,11 +98,12 @@ impl EscalateServer {
|
||||
params: ExecParams,
|
||||
cancel_rx: CancellationToken,
|
||||
sandbox_state: &SandboxState,
|
||||
command_executor: &dyn ShellCommandExecutor,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
|
||||
let client_socket = escalate_client.into_inner();
|
||||
// Only the client endpoint should cross exec into the wrapper process.
|
||||
client_socket.set_cloexec(false)?;
|
||||
|
||||
let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone()));
|
||||
let mut env = std::env::vars().collect::<HashMap<String, String>>();
|
||||
env.insert(
|
||||
@@ -91,47 +119,27 @@ impl EscalateServer {
|
||||
self.execve_wrapper.to_string_lossy().to_string(),
|
||||
);
|
||||
|
||||
let ExecParams {
|
||||
command,
|
||||
workdir,
|
||||
timeout_ms: _,
|
||||
login,
|
||||
} = params;
|
||||
let result = codex_core::exec::process_exec_tool_call(
|
||||
codex_core::exec::ExecParams {
|
||||
command: vec![
|
||||
self.bash_path.to_string_lossy().to_string(),
|
||||
if login == Some(false) {
|
||||
"-c".to_string()
|
||||
} else {
|
||||
"-lc".to_string()
|
||||
},
|
||||
command,
|
||||
],
|
||||
cwd: PathBuf::from(&workdir),
|
||||
expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx),
|
||||
env,
|
||||
network: None,
|
||||
sandbox_permissions: codex_core::sandboxing::SandboxPermissions::UseDefault,
|
||||
windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
let command = vec![
|
||||
self.bash_path.to_string_lossy().to_string(),
|
||||
if params.login == Some(false) {
|
||||
"-c".to_string()
|
||||
} else {
|
||||
"-lc".to_string()
|
||||
},
|
||||
&sandbox_state.sandbox_policy,
|
||||
&sandbox_state.sandbox_cwd,
|
||||
&sandbox_state.codex_linux_sandbox_exe,
|
||||
sandbox_state.use_linux_sandbox_bwrap,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
params.command,
|
||||
];
|
||||
let result = command_executor
|
||||
.run(
|
||||
command,
|
||||
PathBuf::from(¶ms.workdir),
|
||||
env,
|
||||
cancel_rx,
|
||||
sandbox_state,
|
||||
)
|
||||
.await?;
|
||||
escalate_task.abort();
|
||||
|
||||
Ok(ExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
})
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,6 +150,7 @@ pub trait EscalationPolicyFactory {
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run_escalate_server(
|
||||
exec_params: ExecParams,
|
||||
sandbox_state: &SandboxState,
|
||||
@@ -150,6 +159,7 @@ pub async fn run_escalate_server(
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
escalation_policy_factory: impl EscalationPolicyFactory,
|
||||
effective_timeout: Duration,
|
||||
command_executor: &dyn ShellCommandExecutor,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
@@ -160,7 +170,7 @@ pub async fn run_escalate_server(
|
||||
);
|
||||
|
||||
escalate_server
|
||||
.exec(exec_params, cancel_token, sandbox_state)
|
||||
.exec(exec_params, cancel_token, sandbox_state, command_executor)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -272,6 +282,7 @@ async fn handle_escalate_session_with_policy(
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -279,7 +290,6 @@ async fn handle_escalate_session_with_policy(
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
pub mod core_shell_escalation;
|
||||
pub mod escalate_client;
|
||||
pub mod escalate_protocol;
|
||||
pub mod escalate_server;
|
||||
pub mod escalation_policy;
|
||||
pub mod socket;
|
||||
pub mod core_shell_escalation;
|
||||
pub mod stopwatch;
|
||||
|
||||
@@ -96,8 +96,8 @@ async fn read_frame_header(
|
||||
while filled < LENGTH_PREFIX_SIZE {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
// The first read should come with a control message containing any FDs.
|
||||
let result = if !captured_control {
|
||||
guard.try_io(|inner| {
|
||||
let read = if !captured_control {
|
||||
match guard.try_io(|inner| {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
|
||||
let (read, control_len) = {
|
||||
let mut msg = MsgHdrMut::new()
|
||||
@@ -109,16 +109,18 @@ async fn read_frame_header(
|
||||
control.truncate(control_len);
|
||||
captured_control = true;
|
||||
Ok(read)
|
||||
})
|
||||
}) {
|
||||
Ok(Ok(read)) => read,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
} else {
|
||||
guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
|
||||
match guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) {
|
||||
Ok(Ok(read)) => read,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
};
|
||||
let Ok(result) = result else {
|
||||
// Would block, try again.
|
||||
continue;
|
||||
};
|
||||
|
||||
let read = result?;
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
@@ -150,12 +152,11 @@ async fn read_frame_payload(
|
||||
let mut filled = 0;
|
||||
while filled < message_len {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..]));
|
||||
let Ok(result) = result else {
|
||||
// Would block, try again.
|
||||
continue;
|
||||
let read = match guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])) {
|
||||
Ok(Ok(read)) => read,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
};
|
||||
let read = result?;
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
@@ -261,7 +262,13 @@ impl AsyncSocket {
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
|
||||
let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
// `socket2::Socket::pair()` also applies "common flags" (including
|
||||
// `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets.
|
||||
// Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC`
|
||||
// explicitly on both endpoints.
|
||||
let (server, client) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
|
||||
server.set_cloexec(true)?;
|
||||
client.set_cloexec(true)?;
|
||||
Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
|
||||
}
|
||||
|
||||
@@ -314,11 +321,11 @@ async fn send_stream_frame(
|
||||
let mut include_fds = !fds.is_empty();
|
||||
while written < frame.len() {
|
||||
let mut guard = socket.writable().await?;
|
||||
let result = guard.try_io(|inner| {
|
||||
send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds)
|
||||
});
|
||||
let bytes_written = match result {
|
||||
Ok(bytes_written) => bytes_written?,
|
||||
let bytes_written = match guard
|
||||
.try_io(|inner| send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds))
|
||||
{
|
||||
Ok(Ok(bytes_written)) => bytes_written,
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(_would_block) => continue,
|
||||
};
|
||||
if bytes_written == 0 {
|
||||
@@ -370,7 +377,13 @@ impl AsyncDatagramSocket {
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(Self, Self)> {
|
||||
let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
// `socket2::Socket::pair()` also applies "common flags" (including
|
||||
// `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets.
|
||||
// Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC`
|
||||
// explicitly on both endpoints.
|
||||
let (server, client) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
|
||||
server.set_cloexec(true)?;
|
||||
client.set_cloexec(true)?;
|
||||
Ok((Self::new(server)?, Self::new(client)?))
|
||||
}
|
||||
|
||||
@@ -472,7 +485,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
@@ -481,7 +494,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
|
||||
Reference in New Issue
Block a user