Refactor rules so tail alternatives no longer expand into a Cartesian product of prefixes

This commit is contained in:
kevin zhao
2025-11-10 15:20:50 -08:00
parent eea9bff1fb
commit 7e79c4dc5b
6 changed files with 198 additions and 77 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -1193,6 +1193,7 @@ name = "codex-execpolicy2"
version = "0.0.0"
dependencies = [
"anyhow",
"multimap",
"serde",
"serde_json",
"starlark",

View File

@@ -19,3 +19,4 @@ serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
starlark = { workspace = true }
thiserror = { workspace = true }
multimap = { workspace = true }

View File

@@ -1,5 +1,6 @@
use std::cell::RefCell;
use multimap::MultiMap;
use starlark::any::ProvidesStaticType;
use starlark::environment::GlobalsBuilder;
use starlark::environment::Module;
@@ -15,6 +16,8 @@ use starlark::values::none::NoneType;
use crate::decision::Decision;
use crate::error::Error;
use crate::error::Result;
use crate::rule::PatternToken;
use crate::rule::PrefixPattern;
use crate::rule::Rule;
pub struct PolicyParser {
@@ -51,14 +54,14 @@ impl PolicyParser {
#[derive(Debug, ProvidesStaticType)]
struct PolicyBuilder {
rules: RefCell<Vec<Rule>>,
rules_by_program: RefCell<MultiMap<String, Rule>>,
next_auto_id: RefCell<i64>,
}
impl PolicyBuilder {
fn new() -> Self {
Self {
rules: RefCell::new(Vec::new()),
rules_by_program: RefCell::new(MultiMap::new()),
next_auto_id: RefCell::new(0),
}
}
@@ -71,68 +74,81 @@ impl PolicyBuilder {
}
fn add_rule(&self, rule: Rule) {
self.rules.borrow_mut().push(rule);
self.rules_by_program
.borrow_mut()
.insert(rule.pattern.first.clone(), rule);
}
fn build(&self) -> crate::policy::Policy {
crate::policy::Policy::new(self.rules.borrow().clone())
crate::policy::Policy::new(self.rules_by_program.borrow().clone())
}
}
#[derive(Debug)]
enum PatternPart {
Single(String),
Alts(Vec<String>),
struct ParsedPattern {
heads: Vec<String>,
tail: Vec<PatternToken>,
}
fn expand_pattern(parts: &[PatternPart]) -> Vec<Vec<String>> {
let mut acc: Vec<Vec<String>> = vec![Vec::new()];
for part in parts {
let alts: Vec<String> = match part {
PatternPart::Single(s) => vec![s.clone()],
PatternPart::Alts(v) => v.clone(),
};
let mut next = Vec::new();
for prefix in &acc {
for alt in &alts {
let mut combined = prefix.clone();
combined.push(alt.clone());
next.push(combined);
}
}
acc = next;
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<ParsedPattern> {
let mut items = pattern.items.into_iter();
let first = items
.next()
.ok_or_else(|| Error::InvalidPattern("pattern cannot be empty".to_string()))?;
let heads = parse_first_token(first)?;
let mut tail = Vec::new();
for item in items {
tail.push(parse_tail_token(item)?);
}
acc
Ok(ParsedPattern { heads, tail })
}
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
let mut parts = Vec::new();
for item in pattern.items {
if let Some(s) = item.unpack_str() {
parts.push(PatternPart::Single(s.to_string()));
continue;
}
fn parse_first_token<'v>(value: Value<'v>) -> Result<Vec<String>> {
if let Some(s) = value.unpack_str() {
return Ok(vec![s.to_string()]);
}
if let Some(list) = ListRef::from_value(value) {
let mut alts = Vec::new();
if let Some(list) = ListRef::from_value(item) {
for value in list.content() {
let s = value.unpack_str().ok_or_else(|| {
Error::InvalidPattern("pattern alternative must be a string".to_string())
})?;
alts.push(s.to_string());
}
} else {
return Err(Error::InvalidPattern(
"pattern element must be a string or list of strings".to_string(),
));
for value in list.content() {
let s = value.unpack_str().ok_or_else(|| {
Error::InvalidPattern("pattern alternative must be a string".to_string())
})?;
alts.push(s.to_string());
}
if alts.is_empty() {
return Err(Error::InvalidPattern(
"pattern alternatives cannot be empty".to_string(),
));
}
parts.push(PatternPart::Alts(alts));
return Ok(alts);
}
Ok(expand_pattern(&parts))
Err(Error::InvalidPattern(
"pattern element must be a string or list of strings".to_string(),
))
}
/// Parses a non-first pattern element into a `PatternToken`.
///
/// A plain string becomes `PatternToken::Single`; a non-empty list of
/// strings becomes `PatternToken::Alts`. Anything else is rejected with
/// `Error::InvalidPattern`.
fn parse_tail_token<'v>(value: Value<'v>) -> Result<PatternToken> {
    // Fast path: a bare string is a single required token.
    if let Some(single) = value.unpack_str() {
        return Ok(PatternToken::Single(single.to_string()));
    }

    // Otherwise the element must be a list of string alternatives.
    let Some(list) = ListRef::from_value(value) else {
        return Err(Error::InvalidPattern(
            "pattern element must be a string or list of strings".to_string(),
        ));
    };

    // Collect alternatives, failing on the first non-string entry.
    let alternatives = list
        .content()
        .iter()
        .map(|item| {
            item.unpack_str().map(str::to_string).ok_or_else(|| {
                Error::InvalidPattern("pattern alternative must be a string".to_string())
            })
        })
        .collect::<Result<Vec<String>>>()?;

    if alternatives.is_empty() {
        return Err(Error::InvalidPattern(
            "pattern alternatives cannot be empty".to_string(),
        ));
    }
    Ok(PatternToken::Alts(alternatives))
}
fn parse_examples<'v>(examples: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
@@ -173,7 +189,7 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
None => Decision::Allow,
};
let prefixes = parse_pattern(pattern)?;
let parsed_pattern = parse_pattern(pattern)?;
let positive_examples: Vec<Vec<String>> =
r#match.map(parse_examples).transpose()?.unwrap_or_default();
@@ -193,13 +209,6 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
builder.alloc_id()
});
let rule = Rule {
id: id.clone(),
prefixes,
decision,
};
rule.validate_examples(&positive_examples, &negative_examples)?;
#[expect(clippy::unwrap_used)]
let builder = eval
.extra
@@ -207,7 +216,19 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
.unwrap()
.downcast_ref::<PolicyBuilder>()
.unwrap();
builder.add_rule(rule);
for head in &parsed_pattern.heads {
let rule = Rule {
id: id.clone(),
pattern: PrefixPattern {
first: head.clone(),
tail: parsed_pattern.tail.clone(),
},
decision,
};
rule.validate_examples(&positive_examples, &negative_examples)?;
builder.add_rule(rule);
}
Ok(NoneType)
}
}

View File

@@ -1,11 +1,12 @@
use crate::decision::Decision;
use crate::rule::Rule;
use multimap::MultiMap;
use serde::Deserialize;
use serde::Serialize;
#[derive(Clone, Debug)]
pub struct Policy {
rules: Vec<Rule>,
rules_by_program: MultiMap<String, Rule>,
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
@@ -15,18 +16,22 @@ pub struct Evaluation {
}
impl Policy {
pub fn new(rules: Vec<Rule>) -> Self {
Self { rules }
pub fn new(rules_by_program: MultiMap<String, Rule>) -> Self {
Self { rules_by_program }
}
pub fn rules(&self) -> &[Rule] {
&self.rules
pub fn rules(&self) -> &MultiMap<String, Rule> {
&self.rules_by_program
}
pub fn evaluate(&self, cmd: &[String]) -> Option<Evaluation> {
let first = cmd.first()?;
let Some(rules) = self.rules_by_program.get_vec(first) else {
return None;
};
let mut matched_rules: Vec<crate::rule::RuleMatch> = Vec::new();
let mut best_decision: Option<Decision> = None;
for rule in &self.rules {
for rule in rules {
if let Some(matched) = rule.matches(cmd) {
let decision = match best_decision {
None => matched.decision,

View File

@@ -4,10 +4,51 @@ use crate::error::Result;
use serde::Deserialize;
use serde::Serialize;
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PatternToken {
    Single(String),
    Alts(Vec<String>),
}

impl PatternToken {
    /// Returns true when `token` is accepted by this pattern element:
    /// an exact match for `Single`, or membership for `Alts`.
    fn matches(&self, token: &str) -> bool {
        // View both variants uniformly as a slice of acceptable strings.
        let candidates: &[String] = match self {
            Self::Single(expected) => std::slice::from_ref(expected),
            Self::Alts(alternatives) => alternatives.as_slice(),
        };
        candidates.iter().any(|candidate| candidate == token)
    }
}

/// A command-prefix pattern: a fixed program name followed by a sequence
/// of tail tokens, each of which may allow alternatives.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PrefixPattern {
    pub first: String,
    pub tail: Vec<PatternToken>,
}

impl PrefixPattern {
    /// Number of command tokens this pattern covers (program + tail).
    pub fn len(&self) -> usize {
        1 + self.tail.len()
    }

    /// If `cmd` starts with this pattern, returns the matched prefix of
    /// `cmd` (exactly `self.len()` tokens); otherwise returns `None`.
    pub fn matches_prefix(&self, cmd: &[String]) -> Option<Vec<String>> {
        let prefix_len = self.len();
        // `get` rejects commands shorter than the pattern without panicking.
        let candidate = cmd.get(..prefix_len)?;
        if candidate[0] != self.first {
            return None;
        }
        let tail_ok = self
            .tail
            .iter()
            .zip(&candidate[1..])
            .all(|(pattern_token, cmd_token)| pattern_token.matches(cmd_token));
        tail_ok.then(|| candidate.to_vec())
    }
}
#[derive(Clone, Debug)]
pub struct Rule {
pub id: String,
pub prefixes: Vec<Vec<String>>,
pub pattern: PrefixPattern,
pub decision: Decision,
}
@@ -20,23 +61,13 @@ pub struct RuleMatch {
impl Rule {
pub fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
for prefix in &self.prefixes {
if prefix.len() > cmd.len() {
continue;
}
if cmd
.iter()
.zip(prefix)
.all(|(cmd_tok, prefix_tok)| cmd_tok == prefix_tok)
{
return Some(RuleMatch {
rule_id: self.id.clone(),
matched_prefix: prefix.clone(),
decision: self.decision,
});
}
}
None
self.pattern
.matches_prefix(cmd)
.map(|matched_prefix| RuleMatch {
rule_id: self.id.clone(),
matched_prefix,
decision: self.decision,
})
}
pub fn validate_examples(

View File

@@ -1,6 +1,7 @@
use codex_execpolicy2::Decision;
use codex_execpolicy2::PolicyParser;
use codex_execpolicy2::RuleMatch;
use codex_execpolicy2::rule::PatternToken;
fn tokens(cmd: &[&str]) -> Vec<String> {
cmd.iter().map(|token| token.to_string()).collect()
@@ -52,6 +53,67 @@ prefix_rule(
assert!(policy.evaluate(&no_match).is_none());
}
#[test]
fn only_first_token_alias_expands_to_multiple_rules() {
    // An alternatives list in the *first* pattern slot is expanded into one
    // rule per alternative ("bash" and "sh" here), so the policy can look
    // rules up by program name.
    let policy_src = r#"
prefix_rule(
    id = "shell",
    pattern = [["bash", "sh"], ["-c", "-l"]],
)
"#;
    let parser = PolicyParser::new("test.policy", policy_src);
    let policy = parser.parse().expect("parse policy");
    // Exactly one rule is registered under each head alternative.
    let bash_rules = policy.rules().get_vec("bash").expect("bash rules");
    let sh_rules = policy.rules().get_vec("sh").expect("sh rules");
    assert_eq!(bash_rules.len(), 1);
    assert_eq!(sh_rules.len(), 1);
    // Both heads evaluate against real commands, and the matched prefix is
    // program + one tail token (pattern length 2).
    for (cmd, prefix) in [
        (
            tokens(&["bash", "-c", "echo", "hi"]),
            tokens(&["bash", "-c"]),
        ),
        (tokens(&["sh", "-l", "echo", "hi"]), tokens(&["sh", "-l"])),
    ] {
        let eval = policy.evaluate(&cmd).expect("match");
        assert_eq!(eval.matched_rules[0].matched_prefix, prefix);
    }
}
#[test]
fn tail_aliases_are_not_cartesian_expanded() {
    // Alternatives *after* the first token must stay inside a single rule as
    // PatternToken::Alts entries rather than being multiplied into one rule
    // per combination (the Cartesian-product behavior this commit removes).
    let policy_src = r#"
prefix_rule(
    id = "npm_install_variants",
    pattern = ["npm", ["i", "install"], ["--legacy-peer-deps", "--no-save"]],
)
"#;
    let parser = PolicyParser::new("test.policy", policy_src);
    let policy = parser.parse().expect("parse policy");
    // With 2x2 tail alternatives, expansion would have produced 4 rules;
    // the fixed behavior keeps exactly one rule under "npm".
    let rules = policy.rules().get_vec("npm").expect("npm rules");
    assert_eq!(rules.len(), 1);
    let rule = &rules[0];
    // The tail keeps each alternatives list as a single Alts token.
    assert_eq!(
        rule.pattern.tail,
        vec![
            PatternToken::Alts(vec!["i".to_string(), "install".to_string()]),
            PatternToken::Alts(vec![
                "--legacy-peer-deps".to_string(),
                "--no-save".to_string()
            ]),
        ],
    );
    // The single rule still matches every tail-alternative combination.
    for cmd in [
        tokens(&["npm", "i", "--legacy-peer-deps"]),
        tokens(&["npm", "install", "--no-save", "leftpad"]),
    ] {
        assert!(policy.evaluate(&cmd).is_some());
    }
}
#[test]
fn match_and_not_match_examples_are_enforced() {
let policy_src = r#"