shell-command: reuse a PowerShell parser process on Windows (#16057)

## Why

`//codex-rs/shell-command:shell-command-unit-tests` became a real
bottleneck in the Windows Bazel lane because repeated calls to
`is_safe_command_windows()` were starting a fresh PowerShell parser
process for every `powershell.exe -Command ...` assertion.

PR #16056 was motivated by that same bottleneck, but its test-only
shortcut was the wrong layer to optimize because it weakened the
end-to-end guarantee that our runtime path really asks PowerShell to
parse the command the way we expect.

This PR attacks the actual cost center instead: it keeps the real
PowerShell parser in the loop, but turns that parser into a long-lived
helper process so both tests and the runtime safe-command path can reuse
it across many requests.

## What Changed

- add `shell-command/src/command_safety/powershell_parser.rs`, which
keeps one mutex-protected parser process per PowerShell executable path
and speaks a simple JSON-over-stdio request/response protocol
- turn `shell-command/src/command_safety/powershell_parser.ps1` into a
long-running parser server with comments explaining the protocol, the
AST-shape restrictions, and why unsupported constructs are rejected
conservatively
- keep request ids and a one-time respawn path so a dead or
desynchronized cached child fails closed instead of silently returning
mixed parser output
- preserve separate parser processes for `powershell.exe` and
`pwsh.exe`, since they do not accept the same language surface
- avoid a direct `PipelineChainAst` type reference in the PowerShell
script so the parser service still runs under Windows PowerShell 5.1 as
well as newer `pwsh`
- make `shell-command/src/command_safety/windows_safe_commands.rs`
delegate to the new parser utility instead of spawning a fresh
PowerShell process for every parse
- add a Windows-only unit test that exercises multiple sequential
requests against the same parser process

## Testing

- adds a Windows-only parser-reuse unit test in `powershell_parser.rs`
- the main end-to-end verification for this change is the Windows CI
lane, because the new service depends on real `powershell.exe` /
`pwsh.exe` behavior
This commit is contained in:
Michael Bolin
2026-03-27 19:33:41 -07:00
committed by GitHub
parent 71923f43a7
commit 142681ef93
4 changed files with 411 additions and 145 deletions

View File

@@ -1,3 +1,5 @@
mod powershell_parser;
pub mod is_dangerous_command;
pub mod is_safe_command;
pub mod windows_safe_commands;

View File

@@ -1,44 +1,98 @@
$ErrorActionPreference = 'Stop'
$ProgressPreference = 'SilentlyContinue'
$payload = $env:CODEX_POWERSHELL_PAYLOAD
if ([string]::IsNullOrEmpty($payload)) {
Write-Output '{"status":"parse_failed"}'
exit 0
}
# Long-lived PowerShell AST parser used by the Rust command-safety layer on Windows.
# The caller starts one child process per PowerShell executable variant and then sends
# newline-delimited JSON requests over stdin:
# { "id": <u64>, "payload": "<base64-encoded UTF-16LE script>" }
# We answer with one compact JSON line per request:
# { "id": <same>, "status": "ok", "commands": [["Get-Content", "foo.txt"]] }
# or:
# { "id": <same>, "status": "parse_failed" | "parse_errors" | "unsupported" }
#
# "unsupported" is intentional: it means the script parsed successfully, but the AST
# included constructs that we conservatively refuse to lower into argv-like command words.
# The Rust side treats that the same way as an unsafe command.
try {
$source =
[System.Text.Encoding]::Unicode.GetString(
[System.Convert]::FromBase64String($payload)
# Use BOM-free UTF-8 on the protocol stream so Rust sees clean JSON lines with no
# leading BOM bytes on the first response.
$utf8 = [System.Text.UTF8Encoding]::new($false)
$stdin = [System.IO.StreamReader]::new([Console]::OpenStandardInput(), $utf8, $false)
$stdout = [System.IO.StreamWriter]::new([Console]::OpenStandardOutput(), $utf8)
$stdout.AutoFlush = $true
function Invoke-ParseRequest {
param($RequestId, $Source)
$tokens = $null
$errors = $null
$ast = $null
try {
$ast = [System.Management.Automation.Language.Parser]::ParseInput(
$Source,
[ref]$tokens,
[ref]$errors
)
} catch {
Write-Output '{"status":"parse_failed"}'
exit 0
} catch {
return @{ id = $RequestId; status = 'parse_failed' }
}
if ($errors.Count -gt 0) {
return @{ id = $RequestId; status = 'parse_errors' }
}
# Only accept AST shapes we can flatten into a list of argv-like command words.
# Anything more dynamic than that becomes "unsupported" instead of being guessed at.
$commands = [System.Collections.ArrayList]::new()
foreach ($statement in $ast.EndBlock.Statements) {
if (-not (Add-CommandsFromPipelineBase $statement $commands)) {
$commands = $null
break
}
}
if ($commands -ne $null) {
$normalized = [System.Collections.ArrayList]::new()
foreach ($cmd in $commands) {
# Convert every successful parse result to an array-of-arrays shape so the Rust
# side can deserialize one uniform representation.
if ($cmd -is [string]) {
$null = $normalized.Add(@($cmd))
continue
}
if ($cmd -is [System.Array] -or $cmd -is [System.Collections.IEnumerable]) {
$null = $normalized.Add(@($cmd))
continue
}
$normalized = $null
break
}
$commands = $normalized
}
if ($commands -eq $null) {
return @{ id = $RequestId; status = 'unsupported' }
}
return @{ id = $RequestId; status = 'ok'; commands = $commands }
}
$tokens = $null
$errors = $null
function Write-Response {
param($Response)
$ast = $null
try {
$ast = [System.Management.Automation.Language.Parser]::ParseInput(
$source,
[ref]$tokens,
[ref]$errors
)
} catch {
Write-Output '{"status":"parse_failed"}'
exit 0
}
if ($errors.Count -gt 0) {
Write-Output '{"status":"parse_errors"}'
exit 0
$stdout.WriteLine(($Response | ConvertTo-Json -Compress -Depth 3))
}
function Convert-CommandElement {
param($element)
# Accept only literal-ish command elements. Variable expansion, subexpressions, splats,
# and other dynamic forms return $null so the whole request becomes unsupported.
if ($element -is [System.Management.Automation.Language.StringConstantExpressionAst]) {
return @($element.Value)
}
@@ -77,6 +131,8 @@ function Convert-PipelineElement {
param($element)
if ($element -is [System.Management.Automation.Language.CommandAst]) {
# Redirections and invocation operators make the command harder to classify safely,
# so reject them rather than trying to normalize them.
if ($element.Redirections.Count -gt 0) {
return $null
}
@@ -104,6 +160,8 @@ function Convert-PipelineElement {
return $null
}
# Allow a parenthesized single pipeline element like "(Get-Content foo.rs -Raw)" so
# the caller still sees the inner command words. More complex expressions stay unsupported.
if ($element.Expression -is [System.Management.Automation.Language.ParenExpressionAst]) {
$innerPipeline = $element.Expression.Pipeline
if ($innerPipeline -and $innerPipeline.PipelineElements.Count -eq 1) {
@@ -156,46 +214,44 @@ function Add-CommandsFromPipelineBase {
return Add-CommandsFromPipelineAst $pipeline $commands
}
if ($pipeline -is [System.Management.Automation.Language.PipelineChainAst]) {
# Windows PowerShell 5.1 does not define PipelineChainAst, so avoid a direct type
# reference here and instead check the runtime type name.
if ($pipeline.GetType().FullName -eq 'System.Management.Automation.Language.PipelineChainAst') {
return Add-CommandsFromPipelineChain $pipeline $commands
}
return $false
}
$commands = [System.Collections.ArrayList]::new()
foreach ($statement in $ast.EndBlock.Statements) {
if (-not (Add-CommandsFromPipelineBase $statement $commands)) {
$commands = $null
break
}
}
if ($commands -ne $null) {
$normalized = [System.Collections.ArrayList]::new()
foreach ($cmd in $commands) {
if ($cmd -is [string]) {
$null = $normalized.Add(@($cmd))
continue
}
if ($cmd -is [System.Array] -or $cmd -is [System.Collections.IEnumerable]) {
$null = $normalized.Add(@($cmd))
continue
}
$normalized = $null
break
# This script stays alive so the Rust caller can amortize PowerShell startup across
# many parse requests. Each request and response is one compact JSON line.
while (($requestLine = $stdin.ReadLine()) -ne $null) {
$request = $null
try {
$request = $requestLine | ConvertFrom-Json
} catch {
Write-Response @{ id = $null; status = 'parse_failed' }
continue
}
$commands = $normalized
}
# We process requests serially, but still echo the id back so the Rust side can
# detect protocol desyncs instead of silently trusting mixed stdout.
$requestId = $request.id
$payload = $request.payload
if ([string]::IsNullOrEmpty($payload)) {
Write-Response @{ id = $requestId; status = 'parse_failed' }
continue
}
$result = if ($commands -eq $null) {
@{ status = 'unsupported' }
} else {
@{ status = 'ok'; commands = $commands }
}
try {
$source =
[System.Text.Encoding]::Unicode.GetString(
[System.Convert]::FromBase64String($payload)
)
} catch {
Write-Response @{ id = $requestId; status = 'parse_failed' }
continue
}
,$result | ConvertTo-Json -Depth 3
Write-Response (Invoke-ParseRequest -RequestId $requestId -Source $source)
}

View File

@@ -0,0 +1,289 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::io::BufRead;
use std::io::BufReader;
use std::io::ErrorKind;
use std::io::Write;
use std::process::Child;
use std::process::ChildStdin;
use std::process::ChildStdout;
use std::process::Command;
use std::process::Stdio;
use std::sync::LazyLock;
use std::sync::Mutex;
use std::sync::PoisonError;
const POWERSHELL_PARSER_SCRIPT: &str = include_str!("powershell_parser.ps1");
/// Parses `script` with the PowerShell parser process associated with `executable`.
///
/// Parser processes are cached per executable path in a process-wide map so repeated safety
/// checks reuse PowerShell startup work while still consulting the real parser every time.
///
/// The whole map sits behind a single mutex: every child speaks a simple request/response
/// protocol over one stdin/stdout pair, so requests have to be serialized anyway. A poisoned
/// lock is recovered rather than propagated, so a panic in an earlier caller does not
/// permanently disable parsing.
pub(super) fn parse_with_powershell_ast(executable: &str, script: &str) -> PowershellParseOutcome {
    static PARSER_PROCESSES: LazyLock<Mutex<HashMap<String, PowershellParserProcess>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));

    let mut cache = match PARSER_PROCESSES.lock() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    };
    parse_with_cached_process(&mut cache, executable, script)
}
/// Result of asking the PowerShell parser to lower a script into argv-like command words.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum PowershellParseOutcome {
/// The parser reported status `"ok"` with at least one command, and every command and
/// every word within it is non-empty. One inner vector per command.
Commands(Vec<Vec<String>>),
/// The parser reported status `"unsupported"`, or reported `"ok"` without a usable
/// command list (missing, empty, or containing an empty command or word).
Unsupported,
/// The parser reported any other status, or the parser process could not be spawned or
/// its request/response exchange failed.
Failed,
}
fn parse_with_cached_process(
parser_processes: &mut HashMap<String, PowershellParserProcess>,
executable: &str,
script: &str,
) -> PowershellParseOutcome {
// `powershell.exe` and `pwsh.exe` do not accept the same language surface, so each
// executable keeps its own parser process and request stream.
let parser_key = executable.to_string();
for attempt in 0..=1 {
if !parser_processes.contains_key(&parser_key) {
match PowershellParserProcess::spawn(executable) {
Ok(process) => {
parser_processes.insert(parser_key.clone(), process);
}
Err(_) => return PowershellParseOutcome::Failed,
}
}
let Some(parser_process) = parser_processes.get_mut(&parser_key) else {
return PowershellParseOutcome::Failed;
};
let parse_result = parser_process.parse(script);
match parse_result {
Ok(outcome) => return outcome,
Err(_) if attempt == 0 => {
// The common failure mode here is that a previously cached child exited or its
// stdio stream became unusable between requests. Drop that process and retry once
// with a fresh child before giving up.
parser_processes.remove(&parser_key);
}
Err(_) => return PowershellParseOutcome::Failed,
}
}
PowershellParseOutcome::Failed
}
/// Encodes `script` as base64 over its UTF-16LE bytes.
///
/// This is the form passed to `-EncodedCommand` when spawning the parser and the form the
/// parser script decodes request payloads from (`[System.Text.Encoding]::Unicode`).
fn encode_powershell_base64(script: &str) -> String {
    let utf16_le: Vec<u8> = script
        .encode_utf16()
        .flat_map(u16::to_le_bytes)
        .collect();
    BASE64_STANDARD.encode(utf16_le)
}
/// Returns the embedded parser script in base64/UTF-16LE form.
///
/// The encoding is computed on first use and then shared for the life of the process.
fn encoded_parser_script() -> &'static str {
    static ENCODED: LazyLock<String> =
        LazyLock::new(|| encode_powershell_base64(POWERSHELL_PARSER_SCRIPT));
    ENCODED.as_str()
}
/// A running PowerShell child process that serves parse requests over its stdin/stdout.
///
/// The child is killed and reaped when this value is dropped.
struct PowershellParserProcess {
// Handle to the child process; used to kill and reap it on drop.
child: Child,
// Write end of the child's stdin: one JSON request line is written per parse.
stdin: ChildStdin,
// Buffered read end of the child's stdout: one JSON response line is read per parse.
stdout: BufReader<ChildStdout>,
// Request ids are monotonic within one child process so the caller can detect protocol
// desynchronization if stdout is contaminated or the child is unexpectedly replaced.
// Starts at 0 and advances with wrapping arithmetic.
next_request_id: u64,
}
impl PowershellParserProcess {
    /// Launches `executable` running the embedded parser script with piped stdin and stdout.
    ///
    /// stderr is discarded. If either pipe handle cannot be taken from the child, the child
    /// is killed and reaped before the error is returned.
    ///
    /// # Errors
    ///
    /// Returns the spawn error, or a `BrokenPipe` error if a pipe handle is missing.
    fn spawn(executable: &str) -> std::io::Result<Self> {
        let mut command = Command::new(executable);
        command
            .args([
                "-NoLogo",
                "-NoProfile",
                "-NonInteractive",
                "-EncodedCommand",
                encoded_parser_script(),
            ])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null());
        let mut child = command.spawn()?;

        let stdin = take_child_stdin(&mut child);
        let stdout = take_child_stdout(&mut child);
        match (stdin, stdout) {
            (Ok(stdin), Ok(stdout)) => Ok(Self {
                child,
                stdin,
                stdout,
                next_request_id: 0,
            }),
            // A stdin failure takes precedence over a stdout failure.
            (Err(error), _) | (_, Err(error)) => {
                kill_child(&mut child);
                Err(error)
            }
        }
    }

    /// Sends `script` to the child as one JSON request line and reads one JSON response line.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be serialized or written, if the child closed
    /// stdout (`UnexpectedEof`), if the response is not valid JSON of the expected shape, or
    /// if the response id differs from the request id (`InvalidData`).
    fn parse(&mut self, script: &str) -> std::io::Result<PowershellParseOutcome> {
        let request_id = self.next_request_id;
        let request = PowershellParserRequest {
            id: request_id,
            payload: encode_powershell_base64(script),
        };
        // Advance the id before any fallible step so a failed request never reuses its id.
        self.next_request_id = request_id.wrapping_add(1);

        let mut wire_line = serialize_request(&request)?;
        wire_line.push('\n');
        self.stdin.write_all(wire_line.as_bytes())?;
        self.stdin.flush()?;

        let mut reply = String::new();
        let bytes_read = self.stdout.read_line(&mut reply)?;
        if bytes_read == 0 {
            return Err(std::io::Error::new(
                ErrorKind::UnexpectedEof,
                "PowerShell parser closed stdout",
            ));
        }

        let response = deserialize_response(&reply)?;
        // Requests are serialized today; the id still catches protocol desyncs if stdout is
        // contaminated or the child process is unexpectedly replaced mid-request. That turns
        // an ambiguous parser result into a hard failure so the caller can discard the child.
        if response.id == request_id {
            return Ok(response.into_outcome());
        }
        Err(std::io::Error::new(
            ErrorKind::InvalidData,
            format!(
                "PowerShell parser returned response id {} for request {}",
                response.id, request_id
            ),
        ))
    }
}
impl Drop for PowershellParserProcess {
fn drop(&mut self) {
kill_child(&mut self.child);
}
}
/// Moves the stdin handle out of `child`.
///
/// # Errors
///
/// Returns a `BrokenPipe` error if the child has no stdin handle (it was not piped, or it
/// was already taken).
fn take_child_stdin(child: &mut Child) -> std::io::Result<ChildStdin> {
    match child.stdin.take() {
        Some(stdin) => Ok(stdin),
        None => Err(std::io::Error::new(
            ErrorKind::BrokenPipe,
            "PowerShell parser child did not expose stdin",
        )),
    }
}
/// Moves the stdout handle out of `child` and wraps it in a `BufReader`.
///
/// # Errors
///
/// Returns a `BrokenPipe` error if the child has no stdout handle (it was not piped, or it
/// was already taken).
fn take_child_stdout(child: &mut Child) -> std::io::Result<BufReader<ChildStdout>> {
    let Some(stdout) = child.stdout.take() else {
        return Err(std::io::Error::new(
            ErrorKind::BrokenPipe,
            "PowerShell parser child did not expose stdout",
        ));
    };
    Ok(BufReader::new(stdout))
}
/// Serializes `request` to a JSON string.
///
/// # Errors
///
/// A serialization failure is reported as an `InvalidData` I/O error so callers can handle
/// it with the same error type as the pipe reads and writes.
fn serialize_request(request: &PowershellParserRequest) -> std::io::Result<String> {
    match serde_json::to_string(request) {
        Ok(json) => Ok(json),
        Err(error) => Err(std::io::Error::new(
            ErrorKind::InvalidData,
            format!("failed to serialize PowerShell parser request: {error}"),
        )),
    }
}
/// Parses one response line from the parser process as JSON.
///
/// # Errors
///
/// A line that is not valid JSON, or does not match the expected response shape, is
/// reported as an `InvalidData` I/O error.
fn deserialize_response(response_line: &str) -> std::io::Result<PowershellParserResponse> {
    match serde_json::from_str(response_line) {
        Ok(response) => Ok(response),
        Err(error) => Err(std::io::Error::new(
            ErrorKind::InvalidData,
            format!("failed to parse PowerShell parser response: {error}"),
        )),
    }
}
/// One parse request, serialized as a single JSON line on the child's stdin.
#[derive(Serialize)]
struct PowershellParserRequest {
// Identifier the child is expected to echo back in its response.
id: u64,
// Script to parse, as base64 over its UTF-16LE bytes.
payload: String,
}
/// One parse response, deserialized from a single JSON line on the child's stdout.
///
/// `deny_unknown_fields` makes a response with any unexpected key fail to deserialize.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct PowershellParserResponse {
// Id echoed by the child; compared against the request id to detect desyncs.
id: u64,
// Status string: "ok" and "unsupported" are recognized, anything else counts as failure.
status: String,
// Parsed commands as argv-like word lists; may be absent from the response.
commands: Option<Vec<Vec<String>>>,
}
impl PowershellParserResponse {
    /// Converts the wire-level response into a [`PowershellParseOutcome`].
    ///
    /// - `"ok"` yields `Commands` only when the command list is present and non-empty and
    ///   every command and every word in it is non-empty; otherwise `Unsupported`.
    /// - `"unsupported"` yields `Unsupported`.
    /// - Any other status yields `Failed`.
    fn into_outcome(self) -> PowershellParseOutcome {
        match self.status.as_str() {
            "ok" => {
                let Some(commands) = self.commands else {
                    return PowershellParseOutcome::Unsupported;
                };
                let has_blank_entry = commands
                    .iter()
                    .any(|cmd| cmd.is_empty() || cmd.iter().any(String::is_empty));
                if commands.is_empty() || has_blank_entry {
                    PowershellParseOutcome::Unsupported
                } else {
                    PowershellParseOutcome::Commands(commands)
                }
            }
            "unsupported" => PowershellParseOutcome::Unsupported,
            _ => PowershellParseOutcome::Failed,
        }
    }
}
/// Kills `child` and then waits for it to exit.
///
/// Both steps are best-effort: their results are deliberately ignored, so this is safe to
/// call on a child that has already exited.
fn kill_child(child: &mut Child) {
    let _kill_result = child.kill();
    let _wait_result = child.wait();
}
#[cfg(all(test, windows))]
mod tests {
    use super::*;
    use crate::powershell::try_find_powershell_executable_blocking;
    use pretty_assertions::assert_eq;

    /// Builds the expected `Commands` outcome from string-literal command words.
    fn commands(expected: &[&[&str]]) -> PowershellParseOutcome {
        PowershellParseOutcome::Commands(
            expected
                .iter()
                .map(|words| words.iter().map(|word| (*word).to_string()).collect())
                .collect(),
        )
    }

    /// Sends two requests in sequence to one spawned parser process and checks both replies.
    /// Returns early without asserting if no PowerShell executable is found.
    #[test]
    fn parser_process_handles_multiple_requests() {
        let Some(powershell) = try_find_powershell_executable_blocking() else {
            return;
        };
        let powershell = powershell.as_path().to_str().unwrap();
        let mut parser = PowershellParserProcess::spawn(powershell).unwrap();

        let cases = [
            (
                "Get-Content 'foo bar'",
                commands(&[&["Get-Content", "foo bar"]]),
            ),
            (
                "Write-Output foo | Measure-Object",
                commands(&[&["Write-Output", "foo"], &["Measure-Object"]]),
            ),
        ];
        // Order matters: both scripts must go through the same child, one after the other.
        for (script, expected) in cases {
            let actual = parser.parse(script).unwrap();
            assert_eq!(actual, expected);
        }
    }
}

View File

@@ -1,12 +1,7 @@
use crate::command_safety::is_dangerous_command::git_global_option_requires_prompt;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde::Deserialize;
use crate::command_safety::powershell_parser::PowershellParseOutcome;
use crate::command_safety::powershell_parser::parse_with_powershell_ast;
use std::path::Path;
use std::process::Command;
use std::sync::LazyLock;
const POWERSHELL_PARSER_SCRIPT: &str = include_str!("powershell_parser.ps1");
/// On Windows, we conservatively allow only clearly read-only PowerShell invocations
/// that match a small safelist. Anything else (including direct CMD commands) is unsafe.
@@ -122,82 +117,6 @@ fn is_powershell_executable(exe: &str) -> bool {
)
}
/// Attempts to parse PowerShell using the real PowerShell parser, returning every pipeline element
/// as a flat argv vector when possible. If parsing fails or the AST includes unsupported constructs,
/// we conservatively reject the command instead of trying to split it manually.
fn parse_with_powershell_ast(executable: &str, script: &str) -> PowershellParseOutcome {
let encoded_script = encode_powershell_base64(script);
let encoded_parser_script = encoded_parser_script();
match Command::new(executable)
.args([
"-NoLogo",
"-NoProfile",
"-NonInteractive",
"-EncodedCommand",
encoded_parser_script,
])
.env("CODEX_POWERSHELL_PAYLOAD", &encoded_script)
.output()
{
Ok(output) if output.status.success() => {
if let Ok(result) =
serde_json::from_slice::<PowershellParserOutput>(output.stdout.as_slice())
{
result.into_outcome()
} else {
PowershellParseOutcome::Failed
}
}
_ => PowershellParseOutcome::Failed,
}
}
fn encode_powershell_base64(script: &str) -> String {
let mut utf16 = Vec::with_capacity(script.len() * 2);
for unit in script.encode_utf16() {
utf16.extend_from_slice(&unit.to_le_bytes());
}
BASE64_STANDARD.encode(utf16)
}
fn encoded_parser_script() -> &'static str {
static ENCODED: LazyLock<String> =
LazyLock::new(|| encode_powershell_base64(POWERSHELL_PARSER_SCRIPT));
&ENCODED
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct PowershellParserOutput {
status: String,
commands: Option<Vec<Vec<String>>>,
}
impl PowershellParserOutput {
fn into_outcome(self) -> PowershellParseOutcome {
match self.status.as_str() {
"ok" => self
.commands
.filter(|commands| {
!commands.is_empty()
&& commands
.iter()
.all(|cmd| !cmd.is_empty() && cmd.iter().all(|word| !word.is_empty()))
})
.map(PowershellParseOutcome::Commands)
.unwrap_or(PowershellParseOutcome::Unsupported),
"unsupported" => PowershellParseOutcome::Unsupported,
_ => PowershellParseOutcome::Failed,
}
}
}
enum PowershellParseOutcome {
Commands(Vec<Vec<String>>),
Unsupported,
Failed,
}
fn join_arguments_as_script(args: &[String]) -> String {
let mut words = Vec::with_capacity(args.len());
if let Some((first, rest)) = args.split_first() {