refactor: rule traits

This commit is contained in:
kevin zhao
2025-11-13 13:00:11 -05:00
parent f7fa9c5c0f
commit dc76907771
5 changed files with 63 additions and 59 deletions

View File

@@ -12,3 +12,4 @@ pub use policy::Evaluation;
pub use policy::Policy;
pub use rule::Rule;
pub use rule::RuleMatch;
pub use rule::RuleRef;

View File

@@ -21,7 +21,7 @@ use crate::policy::validate_match_examples;
use crate::rule::PatternToken;
use crate::rule::PrefixPattern;
use crate::rule::PrefixRule;
use crate::rule::Rule;
use crate::rule::RuleRef;
// todo: support parsing multiple policies
pub struct PolicyParser;
@@ -48,7 +48,7 @@ impl PolicyParser {
#[derive(Debug, ProvidesStaticType)]
struct PolicyBuilder {
rules_by_program: Mutex<MultiMap<String, Rule>>,
rules_by_program: Mutex<MultiMap<String, RuleRef>>,
}
impl PolicyBuilder {
@@ -58,7 +58,7 @@ impl PolicyBuilder {
}
}
fn add_rule(&self, rule: Rule) {
fn add_rule(&self, rule: RuleRef) {
self.rules_by_program
.lock()
.insert(rule.program().to_string(), rule);
@@ -214,17 +214,17 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
let rest: Arc<[PatternToken]> = remaining_tokens.to_vec().into();
let rules: Vec<Rule> = first_token
let rules: Vec<RuleRef> = first_token
.alternatives()
.iter()
.map(|head| {
Rule::Prefix(PrefixRule {
Arc::new(PrefixRule {
pattern: PrefixPattern {
first: Arc::from(head.as_str()),
rest: rest.clone(),
},
decision,
})
}) as RuleRef
})
.collect();

View File

@@ -1,8 +1,8 @@
use crate::decision::Decision;
use crate::error::Error;
use crate::error::Result;
use crate::rule::Rule;
use crate::rule::RuleMatch;
use crate::rule::RuleRef;
use multimap::MultiMap;
use serde::Deserialize;
use serde::Serialize;
@@ -10,15 +10,15 @@ use shlex::try_join;
#[derive(Clone, Debug)]
pub struct Policy {
rules_by_program: MultiMap<String, Rule>,
rules_by_program: MultiMap<String, RuleRef>,
}
impl Policy {
pub fn new(rules_by_program: MultiMap<String, Rule>) -> Self {
pub fn new(rules_by_program: MultiMap<String, RuleRef>) -> Self {
Self { rules_by_program }
}
pub fn rules(&self) -> &MultiMap<String, Rule> {
pub fn rules(&self) -> &MultiMap<String, RuleRef> {
&self.rules_by_program
}
@@ -60,7 +60,7 @@ impl Evaluation {
}
/// Count how many rules match each provided example and error if any example is unmatched.
pub(crate) fn validate_match_examples(rules: &[Rule], matches: &[Vec<String>]) -> Result<()> {
pub(crate) fn validate_match_examples(rules: &[RuleRef], matches: &[Vec<String>]) -> Result<()> {
let match_counts = rules.iter().fold(vec![0; matches.len()], |counts, rule| {
counts
.iter()

View File

@@ -4,6 +4,8 @@ use crate::error::Result;
use serde::Deserialize;
use serde::Serialize;
use shlex::try_join;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
/// Matches a single command token, either a fixed string or one of several allowed alternatives.
@@ -77,24 +79,20 @@ pub struct PrefixRule {
pub decision: Decision,
}
impl PrefixRule {
pub fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
self.pattern
.matches_prefix(cmd)
.map(|matched_prefix| RuleMatch::PrefixRuleMatch {
matched_prefix,
decision: self.decision,
})
}
pub trait Rule: Any + Debug + Send + Sync {
fn program(&self) -> &str;
pub fn validate_matches(&self, matches: &[Vec<String>]) -> Vec<bool> {
fn matches(&self, cmd: &[String]) -> Option<RuleMatch>;
/// Return a boolean for each example indicating whether this rule matches it.
fn validate_matches(&self, matches: &[Vec<String>]) -> Vec<bool> {
matches
.iter()
.map(|example| self.matches(example).is_some())
.collect()
}
pub fn validate_not_matches(&self, not_matches: &[Vec<String>]) -> Result<()> {
fn validate_not_matches(&self, not_matches: &[Vec<String>]) -> Result<()> {
for example in not_matches {
if self.matches(example).is_some() {
return Err(Error::ExampleDidMatch {
@@ -108,34 +106,19 @@ impl PrefixRule {
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Rule {
Prefix(PrefixRule),
}
pub type RuleRef = Arc<dyn Rule>;
impl Rule {
pub fn program(&self) -> &str {
match self {
Self::Prefix(rule) => rule.pattern.first.as_ref(),
}
impl Rule for PrefixRule {
fn program(&self) -> &str {
self.pattern.first.as_ref()
}
pub fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
match self {
Self::Prefix(rule) => rule.matches(cmd),
}
}
/// Return a boolean for each example indicating whether this rule matches it.
pub fn validate_matches(&self, matches: &[Vec<String>]) -> Vec<bool> {
match self {
Self::Prefix(rule) => rule.validate_matches(matches),
}
}
pub fn validate_not_matches(&self, not_matches: &[Vec<String>]) -> Result<()> {
match self {
Self::Prefix(rule) => rule.validate_not_matches(not_matches),
}
fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
self.pattern
.matches_prefix(cmd)
.map(|matched_prefix| RuleMatch::PrefixRuleMatch {
matched_prefix,
decision: self.decision,
})
}
}

View File

@@ -1,10 +1,11 @@
use std::any::Any;
use std::sync::Arc;
use codex_execpolicy2::Decision;
use codex_execpolicy2::Evaluation;
use codex_execpolicy2::PolicyParser;
use codex_execpolicy2::Rule;
use codex_execpolicy2::RuleMatch;
use codex_execpolicy2::RuleRef;
use codex_execpolicy2::rule::PatternToken;
use codex_execpolicy2::rule::PrefixPattern;
use codex_execpolicy2::rule::PrefixRule;
@@ -14,6 +15,25 @@ fn tokens(cmd: &[&str]) -> Vec<String> {
cmd.iter().map(std::string::ToString::to_string).collect()
}
#[derive(Clone, Debug, Eq, PartialEq)]
enum RuleSnapshot {
Prefix(PrefixRule),
}
fn rule_snapshots(rules: &[RuleRef]) -> Vec<RuleSnapshot> {
rules
.iter()
.map(|rule| {
let rule_any = rule.as_ref() as &dyn Any;
if let Some(prefix_rule) = rule_any.downcast_ref::<PrefixRule>() {
RuleSnapshot::Prefix(prefix_rule.clone())
} else {
panic!("unexpected rule type in RuleRef: {rule:?}");
}
})
.collect()
}
#[test]
fn basic_match() {
let policy_src = r#"
@@ -45,27 +65,27 @@ prefix_rule(
"#;
let policy = PolicyParser::parse("test.codexpolicy", policy_src).expect("parse policy");
let bash_rules = policy.rules().get_vec("bash").expect("bash rules");
let sh_rules = policy.rules().get_vec("sh").expect("sh rules");
let bash_rules = rule_snapshots(policy.rules().get_vec("bash").expect("bash rules"));
let sh_rules = rule_snapshots(policy.rules().get_vec("sh").expect("sh rules"));
assert_eq!(
vec![Rule::Prefix(PrefixRule {
vec![RuleSnapshot::Prefix(PrefixRule {
pattern: PrefixPattern {
first: Arc::from("bash"),
rest: vec![PatternToken::Alts(vec!["-c".to_string(), "-l".to_string()])].into(),
},
decision: Decision::Allow,
})],
bash_rules.clone()
bash_rules
);
assert_eq!(
vec![Rule::Prefix(PrefixRule {
vec![RuleSnapshot::Prefix(PrefixRule {
pattern: PrefixPattern {
first: Arc::from("sh"),
rest: vec![PatternToken::Alts(vec!["-c".to_string(), "-l".to_string()])].into(),
},
decision: Decision::Allow,
})],
sh_rules.clone()
sh_rules
);
let bash_eval = policy.check(&tokens(&["bash", "-c", "echo", "hi"]));
@@ -102,23 +122,23 @@ prefix_rule(
"#;
let policy = PolicyParser::parse("test.codexpolicy", policy_src).expect("parse policy");
let rules = policy.rules().get_vec("npm").expect("npm rules");
let rules = rule_snapshots(policy.rules().get_vec("npm").expect("npm rules"));
assert_eq!(
vec![Rule::Prefix(PrefixRule {
vec![RuleSnapshot::Prefix(PrefixRule {
pattern: PrefixPattern {
first: Arc::from("npm"),
rest: vec![
PatternToken::Alts(vec!["i".to_string(), "install".to_string()]),
PatternToken::Alts(vec![
"--legacy-peer-deps".to_string(),
"--no-save".to_string()
"--no-save".to_string(),
]),
]
.into(),
},
decision: Decision::Allow,
})],
rules.clone()
rules
);
let npm_i = policy.check(&tokens(&["npm", "i", "--legacy-peer-deps"]));