Refactor requirement vals with source

This commit is contained in:
gt-oai
2026-02-04 00:01:14 +00:00
parent 08926a3fb7
commit 31c9bddabc
3 changed files with 93 additions and 24 deletions

View File

@@ -1588,9 +1588,9 @@ impl Config {
model_provider_id,
model_provider,
cwd: resolved_cwd,
approval_policy: constrained_approval_policy,
sandbox_policy: constrained_sandbox_policy,
enforce_residency,
approval_policy: constrained_approval_policy.value,
sandbox_policy: constrained_sandbox_policy.value,
enforce_residency: enforce_residency.value,
did_user_set_custom_approval_policy_or_sandbox_mode,
forced_auto_mode_downgraded_on_windows,
shell_environment_policy,

View File

@@ -44,25 +44,57 @@ impl fmt::Display for RequirementSource {
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct ConstrainedWithSource<T> {
pub value: Constrained<T>,
pub source: Option<RequirementSource>,
}
impl<T> ConstrainedWithSource<T> {
pub fn new(value: Constrained<T>, source: Option<RequirementSource>) -> Self {
Self { value, source }
}
}
impl<T> std::ops::Deref for ConstrainedWithSource<T> {
type Target = Constrained<T>;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl<T> std::ops::DerefMut for ConstrainedWithSource<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
}
/// Normalized version of [`ConfigRequirementsToml`] after deserialization and
/// normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfigRequirements {
pub approval_policy: Constrained<AskForApproval>,
pub sandbox_policy: Constrained<SandboxPolicy>,
pub approval_policy: ConstrainedWithSource<AskForApproval>,
pub sandbox_policy: ConstrainedWithSource<SandboxPolicy>,
pub mcp_servers: Option<Sourced<BTreeMap<String, McpServerRequirement>>>,
pub(crate) exec_policy: Option<Sourced<RequirementsExecPolicy>>,
pub enforce_residency: Constrained<Option<ResidencyRequirement>>,
pub enforce_residency: ConstrainedWithSource<Option<ResidencyRequirement>>,
}
impl Default for ConfigRequirements {
fn default() -> Self {
Self {
approval_policy: Constrained::allow_any_from_default(),
sandbox_policy: Constrained::allow_any(SandboxPolicy::ReadOnly),
approval_policy: ConstrainedWithSource::new(
Constrained::allow_any_from_default(),
None,
),
sandbox_policy: ConstrainedWithSource::new(
Constrained::allow_any(SandboxPolicy::ReadOnly),
None,
),
mcp_servers: None,
exec_policy: None,
enforce_residency: Constrained::allow_any(None),
enforce_residency: ConstrainedWithSource::new(Constrained::allow_any(None), None),
}
}
}
@@ -228,7 +260,7 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
enforce_residency,
} = toml;
let approval_policy: Constrained<AskForApproval> = match allowed_approval_policies {
let approval_policy = match allowed_approval_policies {
Some(Sourced {
value: policies,
source: requirement_source,
@@ -237,7 +269,8 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
return Err(ConstraintError::empty_field("allowed_approval_policies"));
};
Constrained::new(initial_value, move |candidate| {
let requirement_source_for_error = requirement_source.clone();
let constrained = Constrained::new(initial_value, move |candidate| {
if policies.contains(candidate) {
Ok(())
} else {
@@ -245,12 +278,13 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
field_name: "approval_policy",
candidate: format!("{candidate:?}"),
allowed: format!("{policies:?}"),
requirement_source: requirement_source.clone(),
requirement_source: requirement_source_for_error.clone(),
})
}
})?
})?;
ConstrainedWithSource::new(constrained, Some(requirement_source))
}
None => Constrained::allow_any_from_default(),
None => ConstrainedWithSource::new(Constrained::allow_any_from_default(), None),
};
// TODO(gt): `ConfigRequirementsToml` should let the author specify the
@@ -261,7 +295,7 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
// additional parameters. Ultimately, we should expand the config
// format to allow specifying those parameters.
let default_sandbox_policy = SandboxPolicy::ReadOnly;
let sandbox_policy: Constrained<SandboxPolicy> = match allowed_sandbox_modes {
let sandbox_policy = match allowed_sandbox_modes {
Some(Sourced {
value: modes,
source: requirement_source,
@@ -275,7 +309,8 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
});
};
Constrained::new(default_sandbox_policy, move |candidate| {
let requirement_source_for_error = requirement_source.clone();
let constrained = Constrained::new(default_sandbox_policy, move |candidate| {
let mode = match candidate {
SandboxPolicy::ReadOnly => SandboxModeRequirement::ReadOnly,
SandboxPolicy::WorkspaceWrite { .. } => {
@@ -293,12 +328,15 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
field_name: "sandbox_mode",
candidate: format!("{mode:?}"),
allowed: format!("{modes:?}"),
requirement_source: requirement_source.clone(),
requirement_source: requirement_source_for_error.clone(),
})
}
})?
})?;
ConstrainedWithSource::new(constrained, Some(requirement_source))
}
None => {
ConstrainedWithSource::new(Constrained::allow_any(default_sandbox_policy), None)
}
None => Constrained::allow_any(default_sandbox_policy),
};
let exec_policy = match rules {
Some(Sourced { value, source }) => {
@@ -313,13 +351,14 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
None => None,
};
let enforce_residency: Constrained<Option<ResidencyRequirement>> = match enforce_residency {
let enforce_residency = match enforce_residency {
Some(Sourced {
value: residency,
source: requirement_source,
}) => {
let required = Some(residency);
Constrained::new(required, move |candidate| {
let requirement_source_for_error = requirement_source.clone();
let constrained = Constrained::new(required, move |candidate| {
if candidate == &required {
Ok(())
} else {
@@ -327,12 +366,13 @@ impl TryFrom<ConfigRequirementsWithSources> for ConfigRequirements {
field_name: "enforce_residency",
candidate: format!("{candidate:?}"),
allowed: format!("{required:?}"),
requirement_source: requirement_source.clone(),
requirement_source: requirement_source_for_error.clone(),
})
}
})?
})?;
ConstrainedWithSource::new(constrained, Some(requirement_source))
}
None => Constrained::allow_any(None),
None => ConstrainedWithSource::new(Constrained::allow_any(None), None),
};
Ok(ConfigRequirements {
approval_policy,
@@ -563,6 +603,34 @@ mod tests {
Ok(())
}
#[test]
fn constrained_fields_store_requirement_source() -> Result<()> {
let source: ConfigRequirementsToml = from_str(
r#"
allowed_approval_policies = ["on-request"]
allowed_sandbox_modes = ["read-only"]
enforce_residency = "us"
"#,
)?;
let source_location = RequirementSource::CloudRequirements;
let mut target = ConfigRequirementsWithSources::default();
target.merge_unset_fields(source_location.clone(), source);
let requirements = ConfigRequirements::try_from(target)?;
assert_eq!(
requirements.approval_policy.source,
Some(source_location.clone())
);
assert_eq!(
requirements.sandbox_policy.source,
Some(source_location.clone())
);
assert_eq!(requirements.enforce_residency.source, Some(source_location));
Ok(())
}
#[test]
fn deserialize_allowed_approval_policies() -> Result<()> {
let toml_str = r#"

View File

@@ -34,6 +34,7 @@ use toml::Value as TomlValue;
pub use cloud_requirements::CloudRequirementsLoader;
pub use config_requirements::ConfigRequirements;
pub use config_requirements::ConfigRequirementsToml;
pub use config_requirements::ConstrainedWithSource;
pub use config_requirements::McpServerIdentity;
pub use config_requirements::McpServerRequirement;
pub use config_requirements::RequirementSource;