mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
feat: shell snapshotting (#7641)
This commit is contained in:
@@ -46,7 +46,7 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
)
|
||||
}
|
||||
|
||||
fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
|
||||
async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ impl ToolHandler for ShellHandler {
|
||||
)
|
||||
}
|
||||
|
||||
fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
|
||||
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
|
||||
match &invocation.payload {
|
||||
ToolPayload::Function { arguments } => {
|
||||
serde_json::from_str::<ShellToolCallParams>(arguments)
|
||||
@@ -148,7 +148,7 @@ impl ToolHandler for ShellCommandHandler {
|
||||
matches!(payload, ToolPayload::Function { .. })
|
||||
}
|
||||
|
||||
fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
|
||||
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
|
||||
let ToolPayload::Function { arguments } = &invocation.payload else {
|
||||
return true;
|
||||
};
|
||||
@@ -307,18 +307,21 @@ mod tests {
|
||||
let bash_shell = Shell {
|
||||
shell_type: ShellType::Bash,
|
||||
shell_path: PathBuf::from("/bin/bash"),
|
||||
shell_snapshot: None,
|
||||
};
|
||||
assert_safe(&bash_shell, "ls -la");
|
||||
|
||||
let zsh_shell = Shell {
|
||||
shell_type: ShellType::Zsh,
|
||||
shell_path: PathBuf::from("/bin/zsh"),
|
||||
shell_snapshot: None,
|
||||
};
|
||||
assert_safe(&zsh_shell, "ls -la");
|
||||
|
||||
let powershell = Shell {
|
||||
shell_type: ShellType::PowerShell,
|
||||
shell_path: PathBuf::from("pwsh.exe"),
|
||||
shell_snapshot: None,
|
||||
};
|
||||
assert_safe(&powershell, "ls -Name");
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::is_safe_command::is_known_safe_command;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::ExecCommandOutputDeltaEvent;
|
||||
use crate::protocol::ExecCommandSource;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::shell::default_user_shell;
|
||||
use crate::shell::Shell;
|
||||
use crate::shell::get_shell_by_model_provided_path;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
@@ -24,6 +22,8 @@ use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::unified_exec::WriteStdinRequest;
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct UnifiedExecHandler;
|
||||
|
||||
@@ -34,8 +34,8 @@ struct ExecCommandArgs {
|
||||
workdir: Option<String>,
|
||||
#[serde(default)]
|
||||
shell: Option<String>,
|
||||
#[serde(default = "default_login")]
|
||||
login: bool,
|
||||
#[serde(default)]
|
||||
login: Option<bool>,
|
||||
#[serde(default = "default_exec_yield_time_ms")]
|
||||
yield_time_ms: u64,
|
||||
#[serde(default)]
|
||||
@@ -66,10 +66,6 @@ fn default_write_stdin_yield_time_ms() -> u64 {
|
||||
250
|
||||
}
|
||||
|
||||
fn default_login() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for UnifiedExecHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
@@ -83,7 +79,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
)
|
||||
}
|
||||
|
||||
fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
|
||||
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
|
||||
let (ToolPayload::Function { arguments } | ToolPayload::UnifiedExec { arguments }) =
|
||||
&invocation.payload
|
||||
else {
|
||||
@@ -93,7 +89,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
let Ok(params) = serde_json::from_str::<ExecCommandArgs>(arguments) else {
|
||||
return true;
|
||||
};
|
||||
let command = get_command(¶ms);
|
||||
let command = get_command(¶ms, invocation.session.user_shell());
|
||||
!is_known_safe_command(&command)
|
||||
}
|
||||
|
||||
@@ -130,9 +126,10 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
})?;
|
||||
let process_id = manager.allocate_process_id().await;
|
||||
|
||||
let command = get_command(&args);
|
||||
let command_for_intercept = get_command(&args, session.user_shell());
|
||||
let ExecCommandArgs {
|
||||
workdir,
|
||||
login,
|
||||
yield_time_ms,
|
||||
max_output_tokens,
|
||||
with_escalated_permissions,
|
||||
@@ -159,7 +156,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
let cwd = workdir.clone().unwrap_or_else(|| context.turn.cwd.clone());
|
||||
|
||||
if let Some(output) = intercept_apply_patch(
|
||||
&command,
|
||||
&command_for_intercept,
|
||||
&cwd,
|
||||
Some(yield_time_ms),
|
||||
context.session.as_ref(),
|
||||
@@ -180,6 +177,14 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
&context.call_id,
|
||||
None,
|
||||
);
|
||||
let command = if login.is_none() {
|
||||
context
|
||||
.session
|
||||
.user_shell()
|
||||
.wrap_command_with_snapshot(&command_for_intercept)
|
||||
} else {
|
||||
command_for_intercept
|
||||
};
|
||||
let emitter = ToolEmitter::unified_exec(
|
||||
&command,
|
||||
cwd.clone(),
|
||||
@@ -254,14 +259,15 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_command(args: &ExecCommandArgs) -> Vec<String> {
|
||||
let shell = if let Some(shell_str) = &args.shell {
|
||||
get_shell_by_model_provided_path(&PathBuf::from(shell_str))
|
||||
} else {
|
||||
default_user_shell()
|
||||
};
|
||||
fn get_command(args: &ExecCommandArgs, session_shell: Arc<Shell>) -> Vec<String> {
|
||||
if let Some(shell_str) = &args.shell {
|
||||
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
|
||||
shell.shell_snapshot = None;
|
||||
return shell.derive_exec_args(&args.cmd, args.login.unwrap_or(true));
|
||||
}
|
||||
|
||||
shell.derive_exec_args(&args.cmd, args.login)
|
||||
let use_login_shell = args.login.unwrap_or(session_shell.shell_snapshot.is_none());
|
||||
session_shell.derive_exec_args(&args.cmd, use_login_shell)
|
||||
}
|
||||
|
||||
fn format_response(response: &UnifiedExecResponse) -> String {
|
||||
@@ -296,6 +302,8 @@ fn format_response(response: &UnifiedExecResponse) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::shell::default_user_shell;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_get_command_uses_default_shell_when_unspecified() {
|
||||
@@ -306,7 +314,7 @@ mod tests {
|
||||
|
||||
assert!(args.shell.is_none());
|
||||
|
||||
let command = get_command(&args);
|
||||
let command = get_command(&args, Arc::new(default_user_shell()));
|
||||
|
||||
assert_eq!(command.len(), 3);
|
||||
assert_eq!(command[2], "echo hello");
|
||||
@@ -321,7 +329,7 @@ mod tests {
|
||||
|
||||
assert_eq!(args.shell.as_deref(), Some("/bin/bash"));
|
||||
|
||||
let command = get_command(&args);
|
||||
let command = get_command(&args, Arc::new(default_user_shell()));
|
||||
|
||||
assert_eq!(command[2], "echo hello");
|
||||
}
|
||||
@@ -335,7 +343,7 @@ mod tests {
|
||||
|
||||
assert_eq!(args.shell.as_deref(), Some("powershell"));
|
||||
|
||||
let command = get_command(&args);
|
||||
let command = get_command(&args, Arc::new(default_user_shell()));
|
||||
|
||||
assert_eq!(command[2], "echo hello");
|
||||
}
|
||||
@@ -349,7 +357,7 @@ mod tests {
|
||||
|
||||
assert_eq!(args.shell.as_deref(), Some("cmd"));
|
||||
|
||||
let command = get_command(&args);
|
||||
let command = get_command(&args, Arc::new(default_user_shell()));
|
||||
|
||||
assert_eq!(command[2], "echo hello");
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ pub trait ToolHandler: Send + Sync {
|
||||
)
|
||||
}
|
||||
|
||||
fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
|
||||
async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ impl ToolRegistry {
|
||||
let output_cell = &output_cell;
|
||||
let invocation = invocation;
|
||||
async move {
|
||||
if handler.is_mutating(&invocation) {
|
||||
if handler.is_mutating(&invocation).await {
|
||||
tracing::trace!("waiting for tool gate");
|
||||
invocation.turn.tool_call_gate.wait_ready().await;
|
||||
tracing::trace!("tool gate released");
|
||||
|
||||
Reference in New Issue
Block a user