mirror of
https://github.com/openai/codex.git
synced 2026-04-30 01:16:54 +00:00
add unit tests and re-add crate back to cargo
This commit is contained in:
@@ -136,6 +136,7 @@ fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct SocketAddressParts<'a> {
|
||||
host: &'a str,
|
||||
port: u16,
|
||||
@@ -170,7 +171,10 @@ fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
|
||||
return SocketAddressParts { host, port };
|
||||
}
|
||||
|
||||
if let Some((host, port)) = host_port.rsplit_once(':')
|
||||
// Only treat `host:port` as such when there's a single `:`. This avoids
|
||||
// accidentally interpreting unbracketed IPv6 addresses as `host:port`.
|
||||
if host_port.bytes().filter(|b| *b == b':').count() == 1
|
||||
&& let Some((host, port)) = host_port.rsplit_once(':')
|
||||
&& let Ok(port) = port.parse::<u16>()
|
||||
{
|
||||
return SocketAddressParts { host, port };
|
||||
@@ -181,3 +185,127 @@ fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
|
||||
port: default_port,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_defaults_for_empty_string() {
|
||||
assert_eq!(
|
||||
parse_host_port("", 1234),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: 1234,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_defaults_for_whitespace() {
|
||||
assert_eq!(
|
||||
parse_host_port(" ", 5555),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: 5555,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_host_port_without_scheme() {
|
||||
assert_eq!(
|
||||
parse_host_port("127.0.0.1:8080", 3128),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: 8080,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_host_port_with_scheme_and_path() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://example.com:8080/some/path", 3128),
|
||||
SocketAddressParts {
|
||||
host: "example.com",
|
||||
port: 8080,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_strips_userinfo() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://user:pass@host.example:5555", 3128),
|
||||
SocketAddressParts {
|
||||
host: "host.example",
|
||||
port: 5555,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_ipv6_with_brackets() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://[::1]:9999", 3128),
|
||||
SocketAddressParts {
|
||||
host: "::1",
|
||||
port: 9999,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_does_not_treat_unbracketed_ipv6_as_host_port() {
|
||||
assert_eq!(
|
||||
parse_host_port("2001:db8::1", 3128),
|
||||
SocketAddressParts {
|
||||
host: "2001:db8::1",
|
||||
port: 3128,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_falls_back_to_default_port_when_port_is_invalid() {
|
||||
assert_eq!(
|
||||
parse_host_port("example.com:notaport", 3128),
|
||||
SocketAddressParts {
|
||||
host: "example.com:notaport",
|
||||
port: 3128,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_maps_localhost_to_loopback() {
|
||||
assert_eq!(
|
||||
resolve_addr("localhost", 3128),
|
||||
"127.0.0.1:3128".parse().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ip_literals() {
|
||||
assert_eq!(resolve_addr("1.2.3.4", 80), "1.2.3.4:80".parse().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ipv6_literals() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://[::1]:8080", 3128),
|
||||
"[::1]:8080".parse().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://example.com:5555", 3128),
|
||||
"127.0.0.1:5555".parse().unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,5 +26,76 @@ pub fn normalize_host(host: &str) -> String {
|
||||
{
|
||||
return host[1..end].to_ascii_lowercase();
|
||||
}
|
||||
host.split(':').next().unwrap_or("").to_ascii_lowercase()
|
||||
|
||||
// The proxy stack should typically hand us a host without a port, but be
|
||||
// defensive and strip `:port` when there is exactly one `:`.
|
||||
if host.bytes().filter(|b| *b == b':').count() == 1 {
|
||||
return host
|
||||
.split(':')
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
}
|
||||
|
||||
// Avoid mangling unbracketed IPv6 literals.
|
||||
host.to_ascii_lowercase()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn method_allowed_full_allows_everything() {
|
||||
assert!(method_allowed(NetworkMode::Full, "GET"));
|
||||
assert!(method_allowed(NetworkMode::Full, "POST"));
|
||||
assert!(method_allowed(NetworkMode::Full, "CONNECT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn method_allowed_limited_allows_only_safe_methods() {
|
||||
assert!(method_allowed(NetworkMode::Limited, "GET"));
|
||||
assert!(method_allowed(NetworkMode::Limited, "HEAD"));
|
||||
assert!(method_allowed(NetworkMode::Limited, "OPTIONS"));
|
||||
assert!(!method_allowed(NetworkMode::Limited, "POST"));
|
||||
assert!(!method_allowed(NetworkMode::Limited, "CONNECT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_loopback_host_handles_localhost_variants() {
|
||||
assert!(is_loopback_host("localhost"));
|
||||
assert!(is_loopback_host("localhost."));
|
||||
assert!(is_loopback_host("LOCALHOST"));
|
||||
assert!(!is_loopback_host("notlocalhost"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_loopback_host_handles_ip_literals() {
|
||||
assert!(is_loopback_host("127.0.0.1"));
|
||||
assert!(is_loopback_host("::1"));
|
||||
assert!(!is_loopback_host("1.2.3.4"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_lowercases_and_trims() {
|
||||
assert_eq!(normalize_host(" ExAmPlE.CoM "), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_port_for_host_port() {
|
||||
assert_eq!(normalize_host("example.com:1234"), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_preserves_unbracketed_ipv6() {
|
||||
assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_brackets_for_ipv6() {
|
||||
assert_eq!(normalize_host("[::1]"), "::1");
|
||||
assert_eq!(normalize_host("[::1]:443"), "::1");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,3 +568,264 @@ fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]
|
||||
fn unix_timestamp() -> i64 {
|
||||
OffsetDateTime::now_utc().unix_timestamp()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::config::NetworkPolicy;
|
||||
use crate::config::NetworkProxyConfig;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn app_state_for_policy(policy: NetworkPolicy) -> AppState {
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
mode: NetworkMode::Full,
|
||||
policy,
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains).unwrap();
|
||||
let deny_set = compile_globset(&config.network_proxy.policy.denied_domains).unwrap();
|
||||
|
||||
let state = ConfigState {
|
||||
config,
|
||||
mtime: None,
|
||||
allow_set,
|
||||
deny_set,
|
||||
mitm: None,
|
||||
cfg_path: PathBuf::from("/nonexistent/config.toml"),
|
||||
blocked: VecDeque::new(),
|
||||
};
|
||||
|
||||
AppState {
|
||||
state: Arc::new(RwLock::new(state)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_denied_wins_over_allowed() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
denied_domains: vec!["example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("example.com").await.unwrap(),
|
||||
(true, "denied".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_requires_allowlist_match() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("example.com").await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("not-example.com").await.unwrap(),
|
||||
(true, "not_allowed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_expands_apex_for_wildcard_patterns() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["*.openai.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("openai.com").await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("api.openai.com").await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_loopback_when_local_binding_disabled() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("127.0.0.1").await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("localhost").await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_allows_loopback_when_explicitly_allowlisted_and_local_binding_disabled() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["localhost".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("localhost").await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_widening_allowed_domains() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
allowed_domains: Some(vec!["example.com".to_string()]),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string(), "evil.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_requires_managed_denied_domains_entries() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
denied_domains: Some(vec!["evil.com".to_string()]),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
denied_domains: vec![],
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_enabling_when_managed_disabled() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
enabled: Some(false),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_allow_local_binding_when_managed_disabled() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
allow_local_binding: Some(false),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
allow_local_binding: true,
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_is_case_insensitive() {
|
||||
let patterns = vec!["ExAmPle.CoM".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("example.com"));
|
||||
assert!(set.is_match("EXAMPLE.COM"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_expands_apex_for_wildcard_patterns() {
|
||||
let patterns = vec!["*.openai.com".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("openai.com"));
|
||||
assert!(set.is_match("api.openai.com"));
|
||||
assert!(!set.is_match("evilopenai.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_dedupes_patterns_without_changing_behavior() {
|
||||
let patterns = vec!["example.com".to_string(), "example.com".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("example.com"));
|
||||
assert!(set.is_match("EXAMPLE.COM"));
|
||||
assert!(!set.is_match("not-example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_rejects_invalid_patterns() {
|
||||
let patterns = vec!["[".to_string()];
|
||||
assert!(compile_globset(&patterns).is_err());
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
#[tokio::test]
|
||||
async fn unix_socket_allowlist_is_respected_on_macos() {
|
||||
let socket_path = "/tmp/example.sock".to_string();
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_unix_sockets: vec![socket_path.clone()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert!(state.is_unix_socket_allowed(&socket_path).await.unwrap());
|
||||
assert!(
|
||||
!state
|
||||
.is_unix_socket_allowed("/tmp/not-allowed.sock")
|
||||
.await
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
#[tokio::test]
|
||||
async fn unix_socket_allowlist_is_rejected_on_non_macos() {
|
||||
let socket_path = "/tmp/example.sock".to_string();
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_unix_sockets: vec![socket_path.clone()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert!(!state.is_unix_socket_allowed(&socket_path).await.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user