mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
refactor rules so no more cartesian product
This commit is contained in:
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -1193,6 +1193,7 @@ name = "codex-execpolicy2"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"multimap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"starlark",
|
||||
|
||||
@@ -19,3 +19,4 @@ serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
starlark = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
multimap = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::cell::RefCell;
|
||||
|
||||
use multimap::MultiMap;
|
||||
use starlark::any::ProvidesStaticType;
|
||||
use starlark::environment::GlobalsBuilder;
|
||||
use starlark::environment::Module;
|
||||
@@ -15,6 +16,8 @@ use starlark::values::none::NoneType;
|
||||
use crate::decision::Decision;
|
||||
use crate::error::Error;
|
||||
use crate::error::Result;
|
||||
use crate::rule::PatternToken;
|
||||
use crate::rule::PrefixPattern;
|
||||
use crate::rule::Rule;
|
||||
|
||||
pub struct PolicyParser {
|
||||
@@ -51,14 +54,14 @@ impl PolicyParser {
|
||||
|
||||
#[derive(Debug, ProvidesStaticType)]
|
||||
struct PolicyBuilder {
|
||||
rules: RefCell<Vec<Rule>>,
|
||||
rules_by_program: RefCell<MultiMap<String, Rule>>,
|
||||
next_auto_id: RefCell<i64>,
|
||||
}
|
||||
|
||||
impl PolicyBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
rules: RefCell::new(Vec::new()),
|
||||
rules_by_program: RefCell::new(MultiMap::new()),
|
||||
next_auto_id: RefCell::new(0),
|
||||
}
|
||||
}
|
||||
@@ -71,68 +74,81 @@ impl PolicyBuilder {
|
||||
}
|
||||
|
||||
fn add_rule(&self, rule: Rule) {
|
||||
self.rules.borrow_mut().push(rule);
|
||||
self.rules_by_program
|
||||
.borrow_mut()
|
||||
.insert(rule.pattern.first.clone(), rule);
|
||||
}
|
||||
|
||||
fn build(&self) -> crate::policy::Policy {
|
||||
crate::policy::Policy::new(self.rules.borrow().clone())
|
||||
crate::policy::Policy::new(self.rules_by_program.borrow().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PatternPart {
|
||||
Single(String),
|
||||
Alts(Vec<String>),
|
||||
struct ParsedPattern {
|
||||
heads: Vec<String>,
|
||||
tail: Vec<PatternToken>,
|
||||
}
|
||||
|
||||
fn expand_pattern(parts: &[PatternPart]) -> Vec<Vec<String>> {
|
||||
let mut acc: Vec<Vec<String>> = vec![Vec::new()];
|
||||
for part in parts {
|
||||
let alts: Vec<String> = match part {
|
||||
PatternPart::Single(s) => vec![s.clone()],
|
||||
PatternPart::Alts(v) => v.clone(),
|
||||
};
|
||||
let mut next = Vec::new();
|
||||
for prefix in &acc {
|
||||
for alt in &alts {
|
||||
let mut combined = prefix.clone();
|
||||
combined.push(alt.clone());
|
||||
next.push(combined);
|
||||
}
|
||||
}
|
||||
acc = next;
|
||||
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<ParsedPattern> {
|
||||
let mut items = pattern.items.into_iter();
|
||||
let first = items
|
||||
.next()
|
||||
.ok_or_else(|| Error::InvalidPattern("pattern cannot be empty".to_string()))?;
|
||||
let heads = parse_first_token(first)?;
|
||||
let mut tail = Vec::new();
|
||||
for item in items {
|
||||
tail.push(parse_tail_token(item)?);
|
||||
}
|
||||
acc
|
||||
Ok(ParsedPattern { heads, tail })
|
||||
}
|
||||
|
||||
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
|
||||
let mut parts = Vec::new();
|
||||
for item in pattern.items {
|
||||
if let Some(s) = item.unpack_str() {
|
||||
parts.push(PatternPart::Single(s.to_string()));
|
||||
continue;
|
||||
}
|
||||
fn parse_first_token<'v>(value: Value<'v>) -> Result<Vec<String>> {
|
||||
if let Some(s) = value.unpack_str() {
|
||||
return Ok(vec![s.to_string()]);
|
||||
}
|
||||
if let Some(list) = ListRef::from_value(value) {
|
||||
let mut alts = Vec::new();
|
||||
if let Some(list) = ListRef::from_value(item) {
|
||||
for value in list.content() {
|
||||
let s = value.unpack_str().ok_or_else(|| {
|
||||
Error::InvalidPattern("pattern alternative must be a string".to_string())
|
||||
})?;
|
||||
alts.push(s.to_string());
|
||||
}
|
||||
} else {
|
||||
return Err(Error::InvalidPattern(
|
||||
"pattern element must be a string or list of strings".to_string(),
|
||||
));
|
||||
for value in list.content() {
|
||||
let s = value.unpack_str().ok_or_else(|| {
|
||||
Error::InvalidPattern("pattern alternative must be a string".to_string())
|
||||
})?;
|
||||
alts.push(s.to_string());
|
||||
}
|
||||
if alts.is_empty() {
|
||||
return Err(Error::InvalidPattern(
|
||||
"pattern alternatives cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
parts.push(PatternPart::Alts(alts));
|
||||
return Ok(alts);
|
||||
}
|
||||
Ok(expand_pattern(&parts))
|
||||
Err(Error::InvalidPattern(
|
||||
"pattern element must be a string or list of strings".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_tail_token<'v>(value: Value<'v>) -> Result<PatternToken> {
|
||||
if let Some(s) = value.unpack_str() {
|
||||
return Ok(PatternToken::Single(s.to_string()));
|
||||
}
|
||||
if let Some(list) = ListRef::from_value(value) {
|
||||
let mut alts = Vec::new();
|
||||
for value in list.content() {
|
||||
let s = value.unpack_str().ok_or_else(|| {
|
||||
Error::InvalidPattern("pattern alternative must be a string".to_string())
|
||||
})?;
|
||||
alts.push(s.to_string());
|
||||
}
|
||||
if alts.is_empty() {
|
||||
return Err(Error::InvalidPattern(
|
||||
"pattern alternatives cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
return Ok(PatternToken::Alts(alts));
|
||||
}
|
||||
Err(Error::InvalidPattern(
|
||||
"pattern element must be a string or list of strings".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_examples<'v>(examples: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
|
||||
@@ -173,7 +189,7 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
None => Decision::Allow,
|
||||
};
|
||||
|
||||
let prefixes = parse_pattern(pattern)?;
|
||||
let parsed_pattern = parse_pattern(pattern)?;
|
||||
|
||||
let positive_examples: Vec<Vec<String>> =
|
||||
r#match.map(parse_examples).transpose()?.unwrap_or_default();
|
||||
@@ -193,13 +209,6 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
builder.alloc_id()
|
||||
});
|
||||
|
||||
let rule = Rule {
|
||||
id: id.clone(),
|
||||
prefixes,
|
||||
decision,
|
||||
};
|
||||
rule.validate_examples(&positive_examples, &negative_examples)?;
|
||||
|
||||
#[expect(clippy::unwrap_used)]
|
||||
let builder = eval
|
||||
.extra
|
||||
@@ -207,7 +216,19 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
.unwrap()
|
||||
.downcast_ref::<PolicyBuilder>()
|
||||
.unwrap();
|
||||
builder.add_rule(rule);
|
||||
|
||||
for head in &parsed_pattern.heads {
|
||||
let rule = Rule {
|
||||
id: id.clone(),
|
||||
pattern: PrefixPattern {
|
||||
first: head.clone(),
|
||||
tail: parsed_pattern.tail.clone(),
|
||||
},
|
||||
decision,
|
||||
};
|
||||
rule.validate_examples(&positive_examples, &negative_examples)?;
|
||||
builder.add_rule(rule);
|
||||
}
|
||||
Ok(NoneType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use crate::decision::Decision;
|
||||
use crate::rule::Rule;
|
||||
use multimap::MultiMap;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Policy {
|
||||
rules: Vec<Rule>,
|
||||
rules_by_program: MultiMap<String, Rule>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
@@ -15,18 +16,22 @@ pub struct Evaluation {
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
pub fn new(rules: Vec<Rule>) -> Self {
|
||||
Self { rules }
|
||||
pub fn new(rules_by_program: MultiMap<String, Rule>) -> Self {
|
||||
Self { rules_by_program }
|
||||
}
|
||||
|
||||
pub fn rules(&self) -> &[Rule] {
|
||||
&self.rules
|
||||
pub fn rules(&self) -> &MultiMap<String, Rule> {
|
||||
&self.rules_by_program
|
||||
}
|
||||
|
||||
pub fn evaluate(&self, cmd: &[String]) -> Option<Evaluation> {
|
||||
let first = cmd.first()?;
|
||||
let Some(rules) = self.rules_by_program.get_vec(first) else {
|
||||
return None;
|
||||
};
|
||||
let mut matched_rules: Vec<crate::rule::RuleMatch> = Vec::new();
|
||||
let mut best_decision: Option<Decision> = None;
|
||||
for rule in &self.rules {
|
||||
for rule in rules {
|
||||
if let Some(matched) = rule.matches(cmd) {
|
||||
let decision = match best_decision {
|
||||
None => matched.decision,
|
||||
|
||||
@@ -4,10 +4,51 @@ use crate::error::Result;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum PatternToken {
|
||||
Single(String),
|
||||
Alts(Vec<String>),
|
||||
}
|
||||
|
||||
impl PatternToken {
|
||||
fn matches(&self, token: &str) -> bool {
|
||||
match self {
|
||||
Self::Single(expected) => expected == token,
|
||||
Self::Alts(alternatives) => alternatives.iter().any(|alt| alt == token),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct PrefixPattern {
|
||||
pub first: String,
|
||||
pub tail: Vec<PatternToken>,
|
||||
}
|
||||
|
||||
impl PrefixPattern {
|
||||
pub fn len(&self) -> usize {
|
||||
self.tail.len() + 1
|
||||
}
|
||||
|
||||
pub fn matches_prefix(&self, cmd: &[String]) -> Option<Vec<String>> {
|
||||
if cmd.len() < self.len() || cmd[0] != self.first {
|
||||
return None;
|
||||
}
|
||||
|
||||
for (pattern_token, cmd_token) in self.tail.iter().zip(&cmd[1..self.len()]) {
|
||||
if !pattern_token.matches(cmd_token) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(cmd[..self.len()].to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Rule {
|
||||
pub id: String,
|
||||
pub prefixes: Vec<Vec<String>>,
|
||||
pub pattern: PrefixPattern,
|
||||
pub decision: Decision,
|
||||
}
|
||||
|
||||
@@ -20,23 +61,13 @@ pub struct RuleMatch {
|
||||
|
||||
impl Rule {
|
||||
pub fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
|
||||
for prefix in &self.prefixes {
|
||||
if prefix.len() > cmd.len() {
|
||||
continue;
|
||||
}
|
||||
if cmd
|
||||
.iter()
|
||||
.zip(prefix)
|
||||
.all(|(cmd_tok, prefix_tok)| cmd_tok == prefix_tok)
|
||||
{
|
||||
return Some(RuleMatch {
|
||||
rule_id: self.id.clone(),
|
||||
matched_prefix: prefix.clone(),
|
||||
decision: self.decision,
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
self.pattern
|
||||
.matches_prefix(cmd)
|
||||
.map(|matched_prefix| RuleMatch {
|
||||
rule_id: self.id.clone(),
|
||||
matched_prefix,
|
||||
decision: self.decision,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn validate_examples(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use codex_execpolicy2::Decision;
|
||||
use codex_execpolicy2::PolicyParser;
|
||||
use codex_execpolicy2::RuleMatch;
|
||||
use codex_execpolicy2::rule::PatternToken;
|
||||
|
||||
fn tokens(cmd: &[&str]) -> Vec<String> {
|
||||
cmd.iter().map(|token| token.to_string()).collect()
|
||||
@@ -52,6 +53,67 @@ prefix_rule(
|
||||
assert!(policy.evaluate(&no_match).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn only_first_token_alias_expands_to_multiple_rules() {
|
||||
let policy_src = r#"
|
||||
prefix_rule(
|
||||
id = "shell",
|
||||
pattern = [["bash", "sh"], ["-c", "-l"]],
|
||||
)
|
||||
"#;
|
||||
let parser = PolicyParser::new("test.policy", policy_src);
|
||||
let policy = parser.parse().expect("parse policy");
|
||||
|
||||
let bash_rules = policy.rules().get_vec("bash").expect("bash rules");
|
||||
let sh_rules = policy.rules().get_vec("sh").expect("sh rules");
|
||||
assert_eq!(bash_rules.len(), 1);
|
||||
assert_eq!(sh_rules.len(), 1);
|
||||
|
||||
for (cmd, prefix) in [
|
||||
(
|
||||
tokens(&["bash", "-c", "echo", "hi"]),
|
||||
tokens(&["bash", "-c"]),
|
||||
),
|
||||
(tokens(&["sh", "-l", "echo", "hi"]), tokens(&["sh", "-l"])),
|
||||
] {
|
||||
let eval = policy.evaluate(&cmd).expect("match");
|
||||
assert_eq!(eval.matched_rules[0].matched_prefix, prefix);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tail_aliases_are_not_cartesian_expanded() {
|
||||
let policy_src = r#"
|
||||
prefix_rule(
|
||||
id = "npm_install_variants",
|
||||
pattern = ["npm", ["i", "install"], ["--legacy-peer-deps", "--no-save"]],
|
||||
)
|
||||
"#;
|
||||
let parser = PolicyParser::new("test.policy", policy_src);
|
||||
let policy = parser.parse().expect("parse policy");
|
||||
|
||||
let rules = policy.rules().get_vec("npm").expect("npm rules");
|
||||
assert_eq!(rules.len(), 1);
|
||||
let rule = &rules[0];
|
||||
assert_eq!(
|
||||
rule.pattern.tail,
|
||||
vec![
|
||||
PatternToken::Alts(vec!["i".to_string(), "install".to_string()]),
|
||||
PatternToken::Alts(vec![
|
||||
"--legacy-peer-deps".to_string(),
|
||||
"--no-save".to_string()
|
||||
]),
|
||||
],
|
||||
);
|
||||
|
||||
for cmd in [
|
||||
tokens(&["npm", "i", "--legacy-peer-deps"]),
|
||||
tokens(&["npm", "install", "--no-save", "leftpad"]),
|
||||
] {
|
||||
assert!(policy.evaluate(&cmd).is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn match_and_not_match_examples_are_enforced() {
|
||||
let policy_src = r#"
|
||||
|
||||
Reference in New Issue
Block a user