feat: shell snapshotting (#7641)

This commit is contained in:
jif-oai
2025-12-09 18:36:58 +00:00
committed by GitHub
parent ac3237721e
commit 7836aeddae
14 changed files with 807 additions and 48 deletions

View File

@@ -46,7 +46,7 @@ impl ToolHandler for ApplyPatchHandler {
)
}
fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
true
}

View File

@@ -76,7 +76,7 @@ impl ToolHandler for ShellHandler {
)
}
fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
match &invocation.payload {
ToolPayload::Function { arguments } => {
serde_json::from_str::<ShellToolCallParams>(arguments)
@@ -148,7 +148,7 @@ impl ToolHandler for ShellCommandHandler {
matches!(payload, ToolPayload::Function { .. })
}
fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
let ToolPayload::Function { arguments } = &invocation.payload else {
return true;
};
@@ -307,18 +307,21 @@ mod tests {
let bash_shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: None,
};
assert_safe(&bash_shell, "ls -la");
let zsh_shell = Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from("/bin/zsh"),
shell_snapshot: None,
};
assert_safe(&zsh_shell, "ls -la");
let powershell = Shell {
shell_type: ShellType::PowerShell,
shell_path: PathBuf::from("pwsh.exe"),
shell_snapshot: None,
};
assert_safe(&powershell, "ls -Name");
}

View File

@@ -1,12 +1,10 @@
use std::path::PathBuf;
use crate::function_tool::FunctionCallError;
use crate::is_safe_command::is_known_safe_command;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandOutputDeltaEvent;
use crate::protocol::ExecCommandSource;
use crate::protocol::ExecOutputStream;
use crate::shell::default_user_shell;
use crate::shell::Shell;
use crate::shell::get_shell_by_model_provided_path;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
@@ -24,6 +22,8 @@ use crate::unified_exec::UnifiedExecSessionManager;
use crate::unified_exec::WriteStdinRequest;
use async_trait::async_trait;
use serde::Deserialize;
use std::path::PathBuf;
use std::sync::Arc;
pub struct UnifiedExecHandler;
@@ -34,8 +34,8 @@ struct ExecCommandArgs {
workdir: Option<String>,
#[serde(default)]
shell: Option<String>,
#[serde(default = "default_login")]
login: bool,
#[serde(default)]
login: Option<bool>,
#[serde(default = "default_exec_yield_time_ms")]
yield_time_ms: u64,
#[serde(default)]
@@ -66,10 +66,6 @@ fn default_write_stdin_yield_time_ms() -> u64 {
250
}
fn default_login() -> bool {
true
}
#[async_trait]
impl ToolHandler for UnifiedExecHandler {
fn kind(&self) -> ToolKind {
@@ -83,7 +79,7 @@ impl ToolHandler for UnifiedExecHandler {
)
}
fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
let (ToolPayload::Function { arguments } | ToolPayload::UnifiedExec { arguments }) =
&invocation.payload
else {
@@ -93,7 +89,7 @@ impl ToolHandler for UnifiedExecHandler {
let Ok(params) = serde_json::from_str::<ExecCommandArgs>(arguments) else {
return true;
};
let command = get_command(&params);
let command = get_command(&params, invocation.session.user_shell());
!is_known_safe_command(&command)
}
@@ -130,9 +126,10 @@ impl ToolHandler for UnifiedExecHandler {
})?;
let process_id = manager.allocate_process_id().await;
let command = get_command(&args);
let command_for_intercept = get_command(&args, session.user_shell());
let ExecCommandArgs {
workdir,
login,
yield_time_ms,
max_output_tokens,
with_escalated_permissions,
@@ -159,7 +156,7 @@ impl ToolHandler for UnifiedExecHandler {
let cwd = workdir.clone().unwrap_or_else(|| context.turn.cwd.clone());
if let Some(output) = intercept_apply_patch(
&command,
&command_for_intercept,
&cwd,
Some(yield_time_ms),
context.session.as_ref(),
@@ -180,6 +177,14 @@ impl ToolHandler for UnifiedExecHandler {
&context.call_id,
None,
);
let command = if login.is_none() {
context
.session
.user_shell()
.wrap_command_with_snapshot(&command_for_intercept)
} else {
command_for_intercept
};
let emitter = ToolEmitter::unified_exec(
&command,
cwd.clone(),
@@ -254,14 +259,15 @@ impl ToolHandler for UnifiedExecHandler {
}
}
fn get_command(args: &ExecCommandArgs) -> Vec<String> {
let shell = if let Some(shell_str) = &args.shell {
get_shell_by_model_provided_path(&PathBuf::from(shell_str))
} else {
default_user_shell()
};
fn get_command(args: &ExecCommandArgs, session_shell: Arc<Shell>) -> Vec<String> {
if let Some(shell_str) = &args.shell {
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
shell.shell_snapshot = None;
return shell.derive_exec_args(&args.cmd, args.login.unwrap_or(true));
}
shell.derive_exec_args(&args.cmd, args.login)
let use_login_shell = args.login.unwrap_or(session_shell.shell_snapshot.is_none());
session_shell.derive_exec_args(&args.cmd, use_login_shell)
}
fn format_response(response: &UnifiedExecResponse) -> String {
@@ -296,6 +302,8 @@ fn format_response(response: &UnifiedExecResponse) -> String {
#[cfg(test)]
mod tests {
use super::*;
use crate::shell::default_user_shell;
use std::sync::Arc;
#[test]
fn test_get_command_uses_default_shell_when_unspecified() {
@@ -306,7 +314,7 @@ mod tests {
assert!(args.shell.is_none());
let command = get_command(&args);
let command = get_command(&args, Arc::new(default_user_shell()));
assert_eq!(command.len(), 3);
assert_eq!(command[2], "echo hello");
@@ -321,7 +329,7 @@ mod tests {
assert_eq!(args.shell.as_deref(), Some("/bin/bash"));
let command = get_command(&args);
let command = get_command(&args, Arc::new(default_user_shell()));
assert_eq!(command[2], "echo hello");
}
@@ -335,7 +343,7 @@ mod tests {
assert_eq!(args.shell.as_deref(), Some("powershell"));
let command = get_command(&args);
let command = get_command(&args, Arc::new(default_user_shell()));
assert_eq!(command[2], "echo hello");
}
@@ -349,7 +357,7 @@ mod tests {
assert_eq!(args.shell.as_deref(), Some("cmd"));
let command = get_command(&args);
let command = get_command(&args, Arc::new(default_user_shell()));
assert_eq!(command[2], "echo hello");
}

View File

@@ -30,7 +30,7 @@ pub trait ToolHandler: Send + Sync {
)
}
fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
false
}
@@ -110,7 +110,7 @@ impl ToolRegistry {
let output_cell = &output_cell;
let invocation = invocation;
async move {
if handler.is_mutating(&invocation) {
if handler.is_mutating(&invocation).await {
tracing::trace!("waiting for tool gate");
invocation.turn.tool_call_gate.wait_ready().await;
tracing::trace!("tool gate released");