mirror of
https://github.com/openai/codex.git
synced 2026-04-28 08:34:54 +00:00
refactor rules so no more cartesian product
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use std::cell::RefCell;
|
||||
|
||||
use multimap::MultiMap;
|
||||
use starlark::any::ProvidesStaticType;
|
||||
use starlark::environment::GlobalsBuilder;
|
||||
use starlark::environment::Module;
|
||||
@@ -15,6 +16,8 @@ use starlark::values::none::NoneType;
|
||||
use crate::decision::Decision;
|
||||
use crate::error::Error;
|
||||
use crate::error::Result;
|
||||
use crate::rule::PatternToken;
|
||||
use crate::rule::PrefixPattern;
|
||||
use crate::rule::Rule;
|
||||
|
||||
pub struct PolicyParser {
|
||||
@@ -51,14 +54,14 @@ impl PolicyParser {
|
||||
|
||||
#[derive(Debug, ProvidesStaticType)]
|
||||
struct PolicyBuilder {
|
||||
rules: RefCell<Vec<Rule>>,
|
||||
rules_by_program: RefCell<MultiMap<String, Rule>>,
|
||||
next_auto_id: RefCell<i64>,
|
||||
}
|
||||
|
||||
impl PolicyBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
rules: RefCell::new(Vec::new()),
|
||||
rules_by_program: RefCell::new(MultiMap::new()),
|
||||
next_auto_id: RefCell::new(0),
|
||||
}
|
||||
}
|
||||
@@ -71,68 +74,81 @@ impl PolicyBuilder {
|
||||
}
|
||||
|
||||
fn add_rule(&self, rule: Rule) {
|
||||
self.rules.borrow_mut().push(rule);
|
||||
self.rules_by_program
|
||||
.borrow_mut()
|
||||
.insert(rule.pattern.first.clone(), rule);
|
||||
}
|
||||
|
||||
fn build(&self) -> crate::policy::Policy {
|
||||
crate::policy::Policy::new(self.rules.borrow().clone())
|
||||
crate::policy::Policy::new(self.rules_by_program.borrow().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PatternPart {
|
||||
Single(String),
|
||||
Alts(Vec<String>),
|
||||
struct ParsedPattern {
|
||||
heads: Vec<String>,
|
||||
tail: Vec<PatternToken>,
|
||||
}
|
||||
|
||||
fn expand_pattern(parts: &[PatternPart]) -> Vec<Vec<String>> {
|
||||
let mut acc: Vec<Vec<String>> = vec![Vec::new()];
|
||||
for part in parts {
|
||||
let alts: Vec<String> = match part {
|
||||
PatternPart::Single(s) => vec![s.clone()],
|
||||
PatternPart::Alts(v) => v.clone(),
|
||||
};
|
||||
let mut next = Vec::new();
|
||||
for prefix in &acc {
|
||||
for alt in &alts {
|
||||
let mut combined = prefix.clone();
|
||||
combined.push(alt.clone());
|
||||
next.push(combined);
|
||||
}
|
||||
}
|
||||
acc = next;
|
||||
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<ParsedPattern> {
|
||||
let mut items = pattern.items.into_iter();
|
||||
let first = items
|
||||
.next()
|
||||
.ok_or_else(|| Error::InvalidPattern("pattern cannot be empty".to_string()))?;
|
||||
let heads = parse_first_token(first)?;
|
||||
let mut tail = Vec::new();
|
||||
for item in items {
|
||||
tail.push(parse_tail_token(item)?);
|
||||
}
|
||||
acc
|
||||
Ok(ParsedPattern { heads, tail })
|
||||
}
|
||||
|
||||
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
|
||||
let mut parts = Vec::new();
|
||||
for item in pattern.items {
|
||||
if let Some(s) = item.unpack_str() {
|
||||
parts.push(PatternPart::Single(s.to_string()));
|
||||
continue;
|
||||
}
|
||||
fn parse_first_token<'v>(value: Value<'v>) -> Result<Vec<String>> {
|
||||
if let Some(s) = value.unpack_str() {
|
||||
return Ok(vec![s.to_string()]);
|
||||
}
|
||||
if let Some(list) = ListRef::from_value(value) {
|
||||
let mut alts = Vec::new();
|
||||
if let Some(list) = ListRef::from_value(item) {
|
||||
for value in list.content() {
|
||||
let s = value.unpack_str().ok_or_else(|| {
|
||||
Error::InvalidPattern("pattern alternative must be a string".to_string())
|
||||
})?;
|
||||
alts.push(s.to_string());
|
||||
}
|
||||
} else {
|
||||
return Err(Error::InvalidPattern(
|
||||
"pattern element must be a string or list of strings".to_string(),
|
||||
));
|
||||
for value in list.content() {
|
||||
let s = value.unpack_str().ok_or_else(|| {
|
||||
Error::InvalidPattern("pattern alternative must be a string".to_string())
|
||||
})?;
|
||||
alts.push(s.to_string());
|
||||
}
|
||||
if alts.is_empty() {
|
||||
return Err(Error::InvalidPattern(
|
||||
"pattern alternatives cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
parts.push(PatternPart::Alts(alts));
|
||||
return Ok(alts);
|
||||
}
|
||||
Ok(expand_pattern(&parts))
|
||||
Err(Error::InvalidPattern(
|
||||
"pattern element must be a string or list of strings".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_tail_token<'v>(value: Value<'v>) -> Result<PatternToken> {
|
||||
if let Some(s) = value.unpack_str() {
|
||||
return Ok(PatternToken::Single(s.to_string()));
|
||||
}
|
||||
if let Some(list) = ListRef::from_value(value) {
|
||||
let mut alts = Vec::new();
|
||||
for value in list.content() {
|
||||
let s = value.unpack_str().ok_or_else(|| {
|
||||
Error::InvalidPattern("pattern alternative must be a string".to_string())
|
||||
})?;
|
||||
alts.push(s.to_string());
|
||||
}
|
||||
if alts.is_empty() {
|
||||
return Err(Error::InvalidPattern(
|
||||
"pattern alternatives cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
return Ok(PatternToken::Alts(alts));
|
||||
}
|
||||
Err(Error::InvalidPattern(
|
||||
"pattern element must be a string or list of strings".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_examples<'v>(examples: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
|
||||
@@ -173,7 +189,7 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
None => Decision::Allow,
|
||||
};
|
||||
|
||||
let prefixes = parse_pattern(pattern)?;
|
||||
let parsed_pattern = parse_pattern(pattern)?;
|
||||
|
||||
let positive_examples: Vec<Vec<String>> =
|
||||
r#match.map(parse_examples).transpose()?.unwrap_or_default();
|
||||
@@ -193,13 +209,6 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
builder.alloc_id()
|
||||
});
|
||||
|
||||
let rule = Rule {
|
||||
id: id.clone(),
|
||||
prefixes,
|
||||
decision,
|
||||
};
|
||||
rule.validate_examples(&positive_examples, &negative_examples)?;
|
||||
|
||||
#[expect(clippy::unwrap_used)]
|
||||
let builder = eval
|
||||
.extra
|
||||
@@ -207,7 +216,19 @@ fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
.unwrap()
|
||||
.downcast_ref::<PolicyBuilder>()
|
||||
.unwrap();
|
||||
builder.add_rule(rule);
|
||||
|
||||
for head in &parsed_pattern.heads {
|
||||
let rule = Rule {
|
||||
id: id.clone(),
|
||||
pattern: PrefixPattern {
|
||||
first: head.clone(),
|
||||
tail: parsed_pattern.tail.clone(),
|
||||
},
|
||||
decision,
|
||||
};
|
||||
rule.validate_examples(&positive_examples, &negative_examples)?;
|
||||
builder.add_rule(rule);
|
||||
}
|
||||
Ok(NoneType)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user