add unit tests and re-add crate back to cargo

This commit is contained in:
viyatb-oai
2025-12-23 18:19:42 -08:00
parent 9d473922e3
commit dc063ff890
6 changed files with 467 additions and 2 deletions

View File

@@ -136,6 +136,7 @@ fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SocketAddressParts<'a> {
host: &'a str,
port: u16,
@@ -170,7 +171,10 @@ fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
return SocketAddressParts { host, port };
}
if let Some((host, port)) = host_port.rsplit_once(':')
// Only treat `host:port` as such when there's a single `:`. This avoids
// accidentally interpreting unbracketed IPv6 addresses as `host:port`.
if host_port.bytes().filter(|b| *b == b':').count() == 1
&& let Some((host, port)) = host_port.rsplit_once(':')
&& let Ok(port) = port.parse::<u16>()
{
return SocketAddressParts { host, port };
@@ -181,3 +185,127 @@ fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
port: default_port,
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn parse_host_port_defaults_for_empty_string() {
assert_eq!(
parse_host_port("", 1234),
SocketAddressParts {
host: "127.0.0.1",
port: 1234,
}
);
}
#[test]
fn parse_host_port_defaults_for_whitespace() {
assert_eq!(
parse_host_port(" ", 5555),
SocketAddressParts {
host: "127.0.0.1",
port: 5555,
}
);
}
#[test]
fn parse_host_port_parses_host_port_without_scheme() {
assert_eq!(
parse_host_port("127.0.0.1:8080", 3128),
SocketAddressParts {
host: "127.0.0.1",
port: 8080,
}
);
}
#[test]
fn parse_host_port_parses_host_port_with_scheme_and_path() {
assert_eq!(
parse_host_port("http://example.com:8080/some/path", 3128),
SocketAddressParts {
host: "example.com",
port: 8080,
}
);
}
#[test]
fn parse_host_port_strips_userinfo() {
assert_eq!(
parse_host_port("http://user:pass@host.example:5555", 3128),
SocketAddressParts {
host: "host.example",
port: 5555,
}
);
}
#[test]
fn parse_host_port_parses_ipv6_with_brackets() {
assert_eq!(
parse_host_port("http://[::1]:9999", 3128),
SocketAddressParts {
host: "::1",
port: 9999,
}
);
}
#[test]
fn parse_host_port_does_not_treat_unbracketed_ipv6_as_host_port() {
assert_eq!(
parse_host_port("2001:db8::1", 3128),
SocketAddressParts {
host: "2001:db8::1",
port: 3128,
}
);
}
#[test]
fn parse_host_port_falls_back_to_default_port_when_port_is_invalid() {
assert_eq!(
parse_host_port("example.com:notaport", 3128),
SocketAddressParts {
host: "example.com:notaport",
port: 3128,
}
);
}
#[test]
fn resolve_addr_maps_localhost_to_loopback() {
assert_eq!(
resolve_addr("localhost", 3128),
"127.0.0.1:3128".parse().unwrap()
);
}
#[test]
fn resolve_addr_parses_ip_literals() {
assert_eq!(resolve_addr("1.2.3.4", 80), "1.2.3.4:80".parse().unwrap());
}
#[test]
fn resolve_addr_parses_ipv6_literals() {
assert_eq!(
resolve_addr("http://[::1]:8080", 3128),
"[::1]:8080".parse().unwrap()
);
}
#[test]
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
assert_eq!(
resolve_addr("http://example.com:5555", 3128),
"127.0.0.1:5555".parse().unwrap()
);
}
}

View File

@@ -26,5 +26,76 @@ pub fn normalize_host(host: &str) -> String {
{
return host[1..end].to_ascii_lowercase();
}
host.split(':').next().unwrap_or("").to_ascii_lowercase()
// The proxy stack should typically hand us a host without a port, but be
// defensive and strip `:port` when there is exactly one `:`.
if host.bytes().filter(|b| *b == b':').count() == 1 {
return host
.split(':')
.next()
.unwrap_or_default()
.to_ascii_lowercase();
}
// Avoid mangling unbracketed IPv6 literals.
host.to_ascii_lowercase()
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn method_allowed_full_allows_everything() {
assert!(method_allowed(NetworkMode::Full, "GET"));
assert!(method_allowed(NetworkMode::Full, "POST"));
assert!(method_allowed(NetworkMode::Full, "CONNECT"));
}
#[test]
fn method_allowed_limited_allows_only_safe_methods() {
assert!(method_allowed(NetworkMode::Limited, "GET"));
assert!(method_allowed(NetworkMode::Limited, "HEAD"));
assert!(method_allowed(NetworkMode::Limited, "OPTIONS"));
assert!(!method_allowed(NetworkMode::Limited, "POST"));
assert!(!method_allowed(NetworkMode::Limited, "CONNECT"));
}
#[test]
fn is_loopback_host_handles_localhost_variants() {
assert!(is_loopback_host("localhost"));
assert!(is_loopback_host("localhost."));
assert!(is_loopback_host("LOCALHOST"));
assert!(!is_loopback_host("notlocalhost"));
}
#[test]
fn is_loopback_host_handles_ip_literals() {
assert!(is_loopback_host("127.0.0.1"));
assert!(is_loopback_host("::1"));
assert!(!is_loopback_host("1.2.3.4"));
}
#[test]
fn normalize_host_lowercases_and_trims() {
assert_eq!(normalize_host(" ExAmPlE.CoM "), "example.com");
}
#[test]
fn normalize_host_strips_port_for_host_port() {
assert_eq!(normalize_host("example.com:1234"), "example.com");
}
#[test]
fn normalize_host_preserves_unbracketed_ipv6() {
assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
}
#[test]
fn normalize_host_strips_brackets_for_ipv6() {
assert_eq!(normalize_host("[::1]"), "::1");
assert_eq!(normalize_host("[::1]:443"), "::1");
}
}

View File

@@ -568,3 +568,264 @@ fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]
fn unix_timestamp() -> i64 {
OffsetDateTime::now_utc().unix_timestamp()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::NetworkPolicy;
use crate::config::NetworkProxyConfig;
use pretty_assertions::assert_eq;
fn app_state_for_policy(policy: NetworkPolicy) -> AppState {
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
mode: NetworkMode::Full,
policy,
..NetworkProxyConfig::default()
},
};
let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains).unwrap();
let deny_set = compile_globset(&config.network_proxy.policy.denied_domains).unwrap();
let state = ConfigState {
config,
mtime: None,
allow_set,
deny_set,
mitm: None,
cfg_path: PathBuf::from("/nonexistent/config.toml"),
blocked: VecDeque::new(),
};
AppState {
state: Arc::new(RwLock::new(state)),
}
}
#[tokio::test]
async fn host_blocked_denied_wins_over_allowed() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
denied_domains: vec!["example.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("example.com").await.unwrap(),
(true, "denied".to_string())
);
}
#[tokio::test]
async fn host_blocked_requires_allowlist_match() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("example.com").await.unwrap(),
(false, String::new())
);
assert_eq!(
state.host_blocked("not-example.com").await.unwrap(),
(true, "not_allowed".to_string())
);
}
#[tokio::test]
async fn host_blocked_expands_apex_for_wildcard_patterns() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*.openai.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("openai.com").await.unwrap(),
(false, String::new())
);
assert_eq!(
state.host_blocked("api.openai.com").await.unwrap(),
(false, String::new())
);
}
#[tokio::test]
async fn host_blocked_rejects_loopback_when_local_binding_disabled() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1").await.unwrap(),
(true, "not_allowed_local".to_string())
);
assert_eq!(
state.host_blocked("localhost").await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
#[tokio::test]
async fn host_blocked_allows_loopback_when_explicitly_allowlisted_and_local_binding_disabled() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["localhost".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("localhost").await.unwrap(),
(false, String::new())
);
}
#[test]
fn validate_policy_against_constraints_disallows_widening_allowed_domains() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["example.com".to_string(), "evil.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_requires_managed_denied_domains_entries() {
let constraints = NetworkProxyConstraints {
denied_domains: Some(vec!["evil.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
denied_domains: vec![],
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_enabling_when_managed_disabled() {
let constraints = NetworkProxyConstraints {
enabled: Some(false),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_allow_local_binding_when_managed_disabled() {
let constraints = NetworkProxyConstraints {
allow_local_binding: Some(false),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
allow_local_binding: true,
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn compile_globset_is_case_insensitive() {
let patterns = vec!["ExAmPle.CoM".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("example.com"));
assert!(set.is_match("EXAMPLE.COM"));
}
#[test]
fn compile_globset_expands_apex_for_wildcard_patterns() {
let patterns = vec!["*.openai.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("openai.com"));
assert!(set.is_match("api.openai.com"));
assert!(!set.is_match("evilopenai.com"));
}
#[test]
fn compile_globset_dedupes_patterns_without_changing_behavior() {
let patterns = vec!["example.com".to_string(), "example.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("example.com"));
assert!(set.is_match("EXAMPLE.COM"));
assert!(!set.is_match("not-example.com"));
}
#[test]
fn compile_globset_rejects_invalid_patterns() {
let patterns = vec!["[".to_string()];
assert!(compile_globset(&patterns).is_err());
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_is_respected_on_macos() {
let socket_path = "/tmp/example.sock".to_string();
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![socket_path.clone()],
..NetworkPolicy::default()
});
assert!(state.is_unix_socket_allowed(&socket_path).await.unwrap());
assert!(
!state
.is_unix_socket_allowed("/tmp/not-allowed.sock")
.await
.unwrap()
);
}
#[cfg(not(target_os = "macos"))]
#[tokio::test]
async fn unix_socket_allowlist_is_rejected_on_non_macos() {
let socket_path = "/tmp/example.sock".to_string();
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![socket_path.clone()],
..NetworkPolicy::default()
});
assert!(!state.is_unix_socket_allowed(&socket_path).await.unwrap());
}
}