mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
feat(network-proxy): add MITM support and gate limited-mode CONNECT (#9859)
## Description - Adds MITM support (CA load/issue, TLS termination, optional body inspection). - Adds `codex-network-proxy init` to create `CODEX_HOME/network_proxy/mitm`. - Enforces limited-mode HTTPS correctly: `CONNECT` requires MITM, otherwise blocked with `mitm_required`. - Keeps `origin/main` layering/reload semantics (managed layers included in reload checks). - Centralizes block reasons (`REASON_MITM_REQUIRED`) and removes `println!`. - Scope is MITM-only (no SOCKS changes). gated by `mitm=false` (default)
This commit is contained in:
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -2039,6 +2039,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"clap",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-home-dir",
|
||||
"codex-utils-rustls-provider",
|
||||
"globset",
|
||||
"pretty_assertions",
|
||||
|
||||
@@ -16,6 +16,7 @@ anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-home-dir = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
globset = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
@@ -34,6 +34,9 @@ allow_upstream_proxy = true
|
||||
dangerously_allow_non_loopback_proxy = false
|
||||
dangerously_allow_non_loopback_admin = false
|
||||
mode = "full" # default when unset; use "limited" for read-only mode
|
||||
# When true, HTTPS CONNECT can be terminated so limited-mode method policy still applies.
|
||||
mitm = false
|
||||
# CA cert/key are managed internally under $CODEX_HOME/proxy/ (ca.pem + ca.key).
|
||||
|
||||
# Hosts must match the allowlist (unless denied).
|
||||
# If `allowed_domains` is empty, the proxy blocks requests until an allowlist is configured.
|
||||
@@ -85,8 +88,9 @@ When a request is blocked, the proxy responds with `403` and includes:
|
||||
- `blocked-by-method-policy`
|
||||
- `blocked-by-policy`
|
||||
|
||||
In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed. HTTPS `CONNECT` and SOCKS5 are
|
||||
blocked because they would bypass method enforcement.
|
||||
In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed. HTTPS `CONNECT` requests require
|
||||
MITM to enforce limited-mode method policy; otherwise they are blocked. SOCKS5 remains blocked in
|
||||
limited mode.
|
||||
|
||||
Websocket clients typically tunnel `wss://` through HTTPS `CONNECT`; those CONNECT targets still go
|
||||
through the same host allowlist/denylist checks.
|
||||
|
||||
344
codex-rs/network-proxy/src/certs.rs
Normal file
344
codex-rs/network-proxy/src/certs.rs
Normal file
@@ -0,0 +1,344 @@
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_utils_home_dir::find_codex_home;
|
||||
use rama_net::tls::ApplicationProtocol;
|
||||
use rama_tls_rustls::dep::pki_types::CertificateDer;
|
||||
use rama_tls_rustls::dep::pki_types::PrivateKeyDer;
|
||||
use rama_tls_rustls::dep::pki_types::pem::PemObject;
|
||||
use rama_tls_rustls::dep::rcgen::BasicConstraints;
|
||||
use rama_tls_rustls::dep::rcgen::CertificateParams;
|
||||
use rama_tls_rustls::dep::rcgen::DistinguishedName;
|
||||
use rama_tls_rustls::dep::rcgen::DnType;
|
||||
use rama_tls_rustls::dep::rcgen::ExtendedKeyUsagePurpose;
|
||||
use rama_tls_rustls::dep::rcgen::IsCa;
|
||||
use rama_tls_rustls::dep::rcgen::Issuer;
|
||||
use rama_tls_rustls::dep::rcgen::KeyPair;
|
||||
use rama_tls_rustls::dep::rcgen::KeyUsagePurpose;
|
||||
use rama_tls_rustls::dep::rcgen::PKCS_ECDSA_P256_SHA256;
|
||||
use rama_tls_rustls::dep::rcgen::SanType;
|
||||
use rama_tls_rustls::dep::rustls;
|
||||
use rama_tls_rustls::server::TlsAcceptorData;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use tracing::info;
|
||||
|
||||
pub(super) struct ManagedMitmCa {
|
||||
issuer: Issuer<'static, KeyPair>,
|
||||
}
|
||||
|
||||
impl ManagedMitmCa {
|
||||
pub(super) fn load_or_create() -> Result<Self> {
|
||||
let (ca_cert_pem, ca_key_pem) = load_or_create_ca()?;
|
||||
let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
|
||||
let issuer: Issuer<'static, KeyPair> =
|
||||
Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key).context("failed to parse CA cert")?;
|
||||
Ok(Self { issuer })
|
||||
}
|
||||
|
||||
pub(super) fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
|
||||
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
|
||||
let cert = CertificateDer::from_pem_slice(cert_pem.as_bytes())
|
||||
.context("failed to parse host cert PEM")?;
|
||||
let key = PrivateKeyDer::from_pem_slice(key_pem.as_bytes())
|
||||
.context("failed to parse host key PEM")?;
|
||||
let mut server_config =
|
||||
rustls::ServerConfig::builder_with_protocol_versions(rustls::ALL_VERSIONS)
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert], key)
|
||||
.context("failed to build rustls server config")?;
|
||||
server_config.alpn_protocols = vec![
|
||||
ApplicationProtocol::HTTP_2.as_bytes().to_vec(),
|
||||
ApplicationProtocol::HTTP_11.as_bytes().to_vec(),
|
||||
];
|
||||
|
||||
Ok(TlsAcceptorData::from(server_config))
|
||||
}
|
||||
}
|
||||
|
||||
fn issue_host_certificate_pem(
|
||||
host: &str,
|
||||
issuer: &Issuer<'_, KeyPair>,
|
||||
) -> Result<(String, String)> {
|
||||
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
let mut params = CertificateParams::new(Vec::new())
|
||||
.map_err(|err| anyhow!("failed to create cert params: {err}"))?;
|
||||
params.subject_alt_names.push(SanType::IpAddress(ip));
|
||||
params
|
||||
} else {
|
||||
CertificateParams::new(vec![host.to_string()])
|
||||
.map_err(|err| anyhow!("failed to create cert params: {err}"))?
|
||||
};
|
||||
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
|
||||
let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
|
||||
.map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
|
||||
let cert = params
|
||||
.signed_by(&key_pair, issuer)
|
||||
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
|
||||
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
}
|
||||
|
||||
const MANAGED_MITM_CA_DIR: &str = "proxy";
|
||||
const MANAGED_MITM_CA_CERT: &str = "ca.pem";
|
||||
const MANAGED_MITM_CA_KEY: &str = "ca.key";
|
||||
|
||||
fn managed_ca_paths() -> Result<(PathBuf, PathBuf)> {
|
||||
let codex_home =
|
||||
find_codex_home().context("failed to resolve CODEX_HOME for managed MITM CA")?;
|
||||
let proxy_dir = codex_home.join(MANAGED_MITM_CA_DIR);
|
||||
Ok((
|
||||
proxy_dir.join(MANAGED_MITM_CA_CERT),
|
||||
proxy_dir.join(MANAGED_MITM_CA_KEY),
|
||||
))
|
||||
}
|
||||
|
||||
fn load_or_create_ca() -> Result<(String, String)> {
|
||||
let (cert_path, key_path) = managed_ca_paths()?;
|
||||
|
||||
if cert_path.exists() || key_path.exists() {
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
return Err(anyhow!(
|
||||
"both managed MITM CA files must exist (cert={}, key={})",
|
||||
cert_path.display(),
|
||||
key_path.display()
|
||||
));
|
||||
}
|
||||
validate_existing_ca_key_file(&key_path)?;
|
||||
let cert_pem = fs::read_to_string(&cert_path)
|
||||
.with_context(|| format!("failed to read CA cert {}", cert_path.display()))?;
|
||||
let key_pem = fs::read_to_string(&key_path)
|
||||
.with_context(|| format!("failed to read CA key {}", key_path.display()))?;
|
||||
return Ok((cert_pem, key_pem));
|
||||
}
|
||||
|
||||
if let Some(parent) = cert_path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("failed to create {}", parent.display()))?;
|
||||
}
|
||||
if let Some(parent) = key_path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("failed to create {}", parent.display()))?;
|
||||
}
|
||||
|
||||
let (cert_pem, key_pem) = generate_ca()?;
|
||||
// The CA key is a high-value secret. Create it atomically with restrictive permissions.
|
||||
// The cert can be world-readable, but we still write it atomically to avoid partial writes.
|
||||
//
|
||||
// We intentionally use create-new semantics: if a key already exists, we should not overwrite
|
||||
// it silently (that would invalidate previously-trusted cert chains).
|
||||
write_atomic_create_new(&key_path, key_pem.as_bytes(), 0o600)
|
||||
.with_context(|| format!("failed to persist CA key {}", key_path.display()))?;
|
||||
if let Err(err) = write_atomic_create_new(&cert_path, cert_pem.as_bytes(), 0o644)
|
||||
.with_context(|| format!("failed to persist CA cert {}", cert_path.display()))
|
||||
{
|
||||
// Avoid leaving a partially-created CA around (cert missing) if the second write fails.
|
||||
let _ = fs::remove_file(&key_path);
|
||||
return Err(err);
|
||||
}
|
||||
let cert_path = cert_path.display();
|
||||
let key_path = key_path.display();
|
||||
info!("generated MITM CA (cert_path={cert_path}, key_path={key_path})");
|
||||
Ok((cert_pem, key_pem))
|
||||
}
|
||||
|
||||
fn generate_ca() -> Result<(String, String)> {
|
||||
let mut params = CertificateParams::default();
|
||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::KeyCertSign,
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::CommonName, "network_proxy MITM CA");
|
||||
params.distinguished_name = dn;
|
||||
|
||||
let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
|
||||
.map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
|
||||
let cert = params
|
||||
.self_signed(&key_pair)
|
||||
.map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
}
|
||||
|
||||
fn write_atomic_create_new(path: &Path, contents: &[u8], mode: u32) -> Result<()> {
|
||||
let parent = path
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("missing parent directory"))?;
|
||||
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
let pid = std::process::id();
|
||||
let file_name = path.file_name().unwrap_or_default().to_string_lossy();
|
||||
let tmp_path = parent.join(format!(".{file_name}.tmp.{pid}.{nanos}"));
|
||||
|
||||
let mut file = open_create_new_with_mode(&tmp_path, mode)?;
|
||||
file.write_all(contents)
|
||||
.with_context(|| format!("failed to write {}", tmp_path.display()))?;
|
||||
file.sync_all()
|
||||
.with_context(|| format!("failed to fsync {}", tmp_path.display()))?;
|
||||
drop(file);
|
||||
|
||||
// Create the final file using "create-new" semantics (no overwrite). `rename` on Unix can
|
||||
// overwrite existing files, so prefer a hard-link, which fails if the destination exists.
|
||||
match fs::hard_link(&tmp_path, path) {
|
||||
Ok(()) => {
|
||||
fs::remove_file(&tmp_path)
|
||||
.with_context(|| format!("failed to remove {}", tmp_path.display()))?;
|
||||
}
|
||||
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
// Best-effort fallback for environments where hard links are not supported.
|
||||
// This is still subject to a TOCTOU race, but the typical case is a private per-user
|
||||
// config directory, where other users cannot create files anyway.
|
||||
if path.exists() {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
fs::rename(&tmp_path, path).with_context(|| {
|
||||
format!(
|
||||
"failed to rename {} -> {}",
|
||||
tmp_path.display(),
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
// Best-effort durability: ensure the directory entry is persisted too.
|
||||
let dir = File::open(parent).with_context(|| format!("failed to open {}", parent.display()))?;
|
||||
dir.sync_all()
|
||||
.with_context(|| format!("failed to fsync {}", parent.display()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn validate_existing_ca_key_file(path: &Path) -> Result<()> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let metadata = fs::symlink_metadata(path)
|
||||
.with_context(|| format!("failed to stat CA key {}", path.display()))?;
|
||||
if metadata.file_type().is_symlink() {
|
||||
return Err(anyhow!(
|
||||
"refusing to use symlink for managed MITM CA key {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
if !metadata.is_file() {
|
||||
return Err(anyhow!(
|
||||
"managed MITM CA key is not a regular file: {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
let mode = metadata.permissions().mode() & 0o777;
|
||||
if mode & 0o077 != 0 {
|
||||
return Err(anyhow!(
|
||||
"managed MITM CA key {} must not be group/world accessible (mode={mode:o}; expected <= 600)",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn validate_existing_ca_key_file(_path: &Path) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn open_create_new_with_mode(path: &Path, mode: u32) -> Result<File> {
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.mode(mode)
|
||||
.open(path)
|
||||
.with_context(|| format!("failed to create {}", path.display()))
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn open_create_new_with_mode(path: &Path, _mode: u32) -> Result<File> {
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.open(path)
|
||||
.with_context(|| format!("failed to create {}", path.display()))
|
||||
}
|
||||
|
||||
#[cfg(all(test, unix))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn validate_existing_ca_key_file_rejects_group_world_permissions() {
|
||||
let dir = tempdir().unwrap();
|
||||
let key_path = dir.path().join("ca.key");
|
||||
fs::write(&key_path, "key").unwrap();
|
||||
fs::set_permissions(&key_path, fs::Permissions::from_mode(0o644)).unwrap();
|
||||
|
||||
let err = validate_existing_ca_key_file(&key_path).unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("group/world accessible"),
|
||||
"unexpected error: {err:#}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_existing_ca_key_file_rejects_symlink() {
|
||||
use std::os::unix::fs::symlink;
|
||||
|
||||
let dir = tempdir().unwrap();
|
||||
let target = dir.path().join("real.key");
|
||||
let link = dir.path().join("ca.key");
|
||||
fs::write(&target, "key").unwrap();
|
||||
symlink(&target, &link).unwrap();
|
||||
|
||||
let err = validate_existing_ca_key_file(&link).unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("symlink"),
|
||||
"unexpected error: {err:#}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_existing_ca_key_file_allows_private_permissions() {
|
||||
let dir = tempdir().unwrap();
|
||||
let key_path = dir.path().join("ca.key");
|
||||
fs::write(&key_path, "key").unwrap();
|
||||
fs::set_permissions(&key_path, fs::Permissions::from_mode(0o600)).unwrap();
|
||||
|
||||
validate_existing_ca_key_file(&key_path).unwrap();
|
||||
}
|
||||
}
|
||||
@@ -45,6 +45,8 @@ pub struct NetworkProxySettings {
|
||||
#[serde(default)]
|
||||
pub allow_unix_sockets: Vec<String>,
|
||||
pub allow_local_binding: bool,
|
||||
#[serde(default)]
|
||||
pub mitm: bool,
|
||||
}
|
||||
|
||||
impl Default for NetworkProxySettings {
|
||||
@@ -65,6 +67,7 @@ impl Default for NetworkProxySettings {
|
||||
denied_domains: Vec::new(),
|
||||
allow_unix_sockets: Vec::new(),
|
||||
allow_local_binding: true,
|
||||
mitm: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -74,6 +77,7 @@ impl Default for NetworkProxySettings {
|
||||
pub enum NetworkMode {
|
||||
/// Limited (read-only) access: only GET/HEAD/OPTIONS are allowed for HTTP. HTTPS CONNECT is
|
||||
/// blocked unless MITM is enabled so the proxy can enforce method policy on inner requests.
|
||||
/// SOCKS5 remains blocked in limited mode.
|
||||
Limited,
|
||||
/// Full network access: all HTTP methods are allowed, and HTTPS CONNECTs are tunneled without
|
||||
/// MITM interception.
|
||||
@@ -393,6 +397,7 @@ mod tests {
|
||||
denied_domains: Vec::new(),
|
||||
allow_unix_sockets: Vec::new(),
|
||||
allow_local_binding: true,
|
||||
mitm: false,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::mitm;
|
||||
use crate::network_policy::NetworkDecision;
|
||||
use crate::network_policy::NetworkDecisionSource;
|
||||
use crate::network_policy::NetworkPolicyDecider;
|
||||
@@ -9,6 +10,7 @@ use crate::network_policy::NetworkProtocol;
|
||||
use crate::network_policy::evaluate_host_policy;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_MITM_REQUIRED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_PROXY_DISABLED;
|
||||
use crate::responses::PolicyDecisionDetails;
|
||||
@@ -49,6 +51,7 @@ use rama_http_backend::server::HttpServer;
|
||||
use rama_http_backend::server::layer::upgrade::UpgradeLayer;
|
||||
use rama_http_backend::server::layer::upgrade::Upgraded;
|
||||
use rama_net::Protocol;
|
||||
use rama_net::address::HostWithOptPort;
|
||||
use rama_net::address::ProxyAddress;
|
||||
use rama_net::client::ConnectorService;
|
||||
use rama_net::client::EstablishedClientConnection;
|
||||
@@ -233,10 +236,20 @@ async fn http_connect_accept(
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read network mode", err))?;
|
||||
|
||||
if mode == NetworkMode::Limited {
|
||||
let mitm_state = match app_state.mitm_state().await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
error!("failed to load MITM state: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
|
||||
if mode == NetworkMode::Limited && mitm_state.is_none() {
|
||||
// Limited mode is designed to be read-only. Without MITM, a CONNECT tunnel would hide the
|
||||
// inner HTTP method/headers from the proxy, effectively bypassing method policy.
|
||||
let details = PolicyDecisionDetails {
|
||||
decision: NetworkPolicyDecision::Deny,
|
||||
reason: REASON_METHOD_NOT_ALLOWED,
|
||||
reason: REASON_MITM_REQUIRED,
|
||||
source: NetworkDecisionSource::ModeGuard,
|
||||
protocol: NetworkProtocol::HttpsConnect,
|
||||
host: &host,
|
||||
@@ -245,7 +258,7 @@ async fn http_connect_accept(
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(BlockedRequestArgs {
|
||||
host: host.clone(),
|
||||
reason: REASON_METHOD_NOT_ALLOWED.to_string(),
|
||||
reason: REASON_MITM_REQUIRED.to_string(),
|
||||
client: client.clone(),
|
||||
method: Some("CONNECT".to_string()),
|
||||
mode: Some(NetworkMode::Limited),
|
||||
@@ -256,15 +269,17 @@ async fn http_connect_accept(
|
||||
}))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked by method policy (client={client}, host={host}, mode=limited)");
|
||||
return Err(blocked_text_with_details(
|
||||
REASON_METHOD_NOT_ALLOWED,
|
||||
&details,
|
||||
));
|
||||
warn!(
|
||||
"CONNECT blocked; MITM required for read-only HTTPS in limited mode (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(blocked_text_with_details(REASON_MITM_REQUIRED, &details));
|
||||
}
|
||||
|
||||
req.extensions_mut().insert(ProxyTarget(authority));
|
||||
req.extensions_mut().insert(mode);
|
||||
if let Some(mitm_state) = mitm_state {
|
||||
req.extensions_mut().insert(mitm_state);
|
||||
}
|
||||
|
||||
Ok((
|
||||
Response::builder()
|
||||
@@ -276,9 +291,34 @@ async fn http_connect_accept(
|
||||
}
|
||||
|
||||
async fn http_connect_proxy(upgraded: Upgraded) -> Result<(), Infallible> {
|
||||
if upgraded.extensions().get::<ProxyTarget>().is_none() {
|
||||
let mode = upgraded
|
||||
.extensions()
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
|
||||
let Some(target) = upgraded
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.map(|t| t.0.clone())
|
||||
else {
|
||||
warn!("CONNECT missing proxy target");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if mode == NetworkMode::Limited
|
||||
&& upgraded
|
||||
.extensions()
|
||||
.get::<Arc<mitm::MitmState>>()
|
||||
.is_some()
|
||||
{
|
||||
let host = normalize_host(&target.host.to_string());
|
||||
let port = target.port;
|
||||
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
|
||||
if let Err(err) = mitm::mitm_tunnel(upgraded).await {
|
||||
warn!("MITM tunnel error: {err}");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let allow_upstream_proxy = match upgraded
|
||||
@@ -465,6 +505,10 @@ async fn http_plain_proxy(
|
||||
};
|
||||
let host = normalize_host(&authority.host.to_string());
|
||||
let port = authority.port;
|
||||
if let Err(err) = validate_plain_http_host_header(&req, &authority) {
|
||||
warn!("HTTP request host mismatch: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
|
||||
}
|
||||
let enabled = match app_state
|
||||
.enabled()
|
||||
.await
|
||||
@@ -669,6 +713,45 @@ fn remove_hop_by_hop_request_headers(headers: &mut HeaderMap) {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_plain_http_host_header(
|
||||
req: &Request,
|
||||
target: &rama_net::address::HostWithPort,
|
||||
) -> std::result::Result<(), &'static str> {
|
||||
// Only enforce this in absolute-form requests. Origin-form requests use the Host header as the
|
||||
// routing authority, so there is no separate target authority to compare against.
|
||||
if req.uri().authority().is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(raw_host) = req.headers().get(header::HOST) else {
|
||||
return Ok(());
|
||||
};
|
||||
let raw_host = raw_host.to_str().map_err(|_| "invalid Host header")?;
|
||||
let parsed = HostWithOptPort::try_from(raw_host).map_err(|_| "invalid Host header")?;
|
||||
|
||||
let target_host = normalize_host(&target.host.to_string());
|
||||
let request_host = normalize_host(&parsed.host.to_string());
|
||||
if request_host.is_empty() || request_host != target_host {
|
||||
return Err("request Host header host does not match target authority");
|
||||
}
|
||||
|
||||
let expected_port = target.port;
|
||||
let request_port = match parsed.port {
|
||||
Some(port) => port,
|
||||
None => match req.uri().scheme_str() {
|
||||
Some("http") => 80,
|
||||
Some("https") => 443,
|
||||
Some(_) | None => expected_port,
|
||||
},
|
||||
};
|
||||
|
||||
if request_port != expected_port {
|
||||
return Err("request Host header port does not match target authority");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn json_blocked(host: &str, reason: &str, details: Option<&PolicyDecisionDetails<'_>>) -> Response {
|
||||
let (message, decision, source, protocol, port) = details
|
||||
.map(|details| {
|
||||
@@ -804,7 +887,7 @@ mod tests {
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
assert_eq!(
|
||||
response.headers().get("x-proxy-error").unwrap(),
|
||||
"blocked-by-method-policy"
|
||||
"blocked-by-mitm-required"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -853,6 +936,23 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_plain_proxy_rejects_absolute_uri_host_header_mismatch() {
|
||||
let state = Arc::new(network_proxy_state_for_policy(
|
||||
NetworkProxySettings::default(),
|
||||
));
|
||||
let mut req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("http://raw.githubusercontent.com/openai/codex/main/README.md")
|
||||
.header(header::HOST, "api.github.com")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
req.extensions_mut().insert(state);
|
||||
|
||||
let response = http_plain_proxy(None, req).await;
|
||||
assert_eq!(response.unwrap().status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_hop_by_hop_request_headers_keeps_forwarding_headers() {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#![deny(clippy::print_stdout, clippy::print_stderr)]
|
||||
|
||||
mod admin;
|
||||
mod certs;
|
||||
mod config;
|
||||
mod http_proxy;
|
||||
mod mitm;
|
||||
mod network_policy;
|
||||
mod policy;
|
||||
mod proxy;
|
||||
|
||||
482
codex-rs/network-proxy/src/mitm.rs
Normal file
482
codex-rs/network-proxy/src/mitm.rs
Normal file
@@ -0,0 +1,482 @@
|
||||
use crate::certs::ManagedMitmCa;
|
||||
use crate::config::NetworkMode;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
|
||||
use crate::responses::blocked_text_response;
|
||||
use crate::responses::text_response;
|
||||
use crate::runtime::HostBlockDecision;
|
||||
use crate::runtime::HostBlockReason;
|
||||
use crate::state::BlockedRequest;
|
||||
use crate::state::BlockedRequestArgs;
|
||||
use crate::state::NetworkProxyState;
|
||||
use crate::upstream::UpstreamClient;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama_core::Layer;
|
||||
use rama_core::Service;
|
||||
use rama_core::bytes::Bytes;
|
||||
use rama_core::error::BoxError;
|
||||
use rama_core::extensions::ExtensionsRef;
|
||||
use rama_core::futures::stream::Stream;
|
||||
use rama_core::rt::Executor;
|
||||
use rama_core::service::service_fn;
|
||||
use rama_http::Body;
|
||||
use rama_http::BodyDataStream;
|
||||
use rama_http::HeaderValue;
|
||||
use rama_http::Request;
|
||||
use rama_http::Response;
|
||||
use rama_http::StatusCode;
|
||||
use rama_http::Uri;
|
||||
use rama_http::header::HOST;
|
||||
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama_http_backend::server::HttpServer;
|
||||
use rama_http_backend::server::layer::upgrade::Upgraded;
|
||||
use rama_net::proxy::ProxyTarget;
|
||||
use rama_net::stream::SocketInfo;
|
||||
use rama_tls_rustls::server::TlsAcceptorData;
|
||||
use rama_tls_rustls::server::TlsAcceptorLayer;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context as TaskContext;
|
||||
use std::task::Poll;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
/// State needed to terminate a CONNECT tunnel and enforce policy on inner HTTPS requests.
|
||||
pub struct MitmState {
|
||||
ca: ManagedMitmCa,
|
||||
upstream: UpstreamClient,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MitmPolicyContext {
|
||||
target_host: String,
|
||||
target_port: u16,
|
||||
mode: NetworkMode,
|
||||
app_state: Arc<NetworkProxyState>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MitmRequestContext {
|
||||
policy: MitmPolicyContext,
|
||||
mitm: Arc<MitmState>,
|
||||
}
|
||||
|
||||
const MITM_INSPECT_BODIES: bool = false;
|
||||
const MITM_MAX_BODY_BYTES: usize = 4096;
|
||||
|
||||
impl std::fmt::Debug for MitmState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Avoid dumping internal state (CA material, connectors, etc.) to logs.
|
||||
f.debug_struct("MitmState")
|
||||
.field("inspect", &self.inspect)
|
||||
.field("max_body_bytes", &self.max_body_bytes)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl MitmState {
|
||||
pub(crate) fn new(allow_upstream_proxy: bool) -> Result<Self> {
|
||||
// MITM exists to make limited-mode HTTPS enforceable: once CONNECT is established, plain
|
||||
// proxying would lose visibility into the inner HTTP request. We generate/load a local CA
|
||||
// and issue per-host leaf certs so we can terminate TLS and apply policy.
|
||||
let ca = ManagedMitmCa::load_or_create()?;
|
||||
|
||||
let upstream = if allow_upstream_proxy {
|
||||
UpstreamClient::from_env_proxy()
|
||||
} else {
|
||||
UpstreamClient::direct()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
ca,
|
||||
upstream,
|
||||
inspect: MITM_INSPECT_BODIES,
|
||||
max_body_bytes: MITM_MAX_BODY_BYTES,
|
||||
})
|
||||
}
|
||||
|
||||
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
|
||||
self.ca.tls_acceptor_data_for_host(host)
|
||||
}
|
||||
|
||||
pub(crate) fn inspect_enabled(&self) -> bool {
|
||||
self.inspect
|
||||
}
|
||||
|
||||
pub(crate) fn max_body_bytes(&self) -> usize {
|
||||
self.max_body_bytes
|
||||
}
|
||||
}
|
||||
|
||||
/// Terminate the upgraded CONNECT stream with a generated leaf cert and proxy inner HTTPS traffic.
|
||||
pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
|
||||
let mitm = upgraded
|
||||
.extensions()
|
||||
.get::<Arc<MitmState>>()
|
||||
.cloned()
|
||||
.context("missing MITM state")?;
|
||||
let app_state = upgraded
|
||||
.extensions()
|
||||
.get::<Arc<NetworkProxyState>>()
|
||||
.cloned()
|
||||
.context("missing app state")?;
|
||||
let target = upgraded
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.context("missing proxy target")?
|
||||
.0
|
||||
.clone();
|
||||
let target_host = normalize_host(&target.host.to_string());
|
||||
let target_port = target.port;
|
||||
let acceptor_data = mitm.tls_acceptor_data_for_host(&target_host)?;
|
||||
let mode = upgraded
|
||||
.extensions()
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
let request_ctx = Arc::new(MitmRequestContext {
|
||||
policy: MitmPolicyContext {
|
||||
target_host,
|
||||
target_port,
|
||||
mode,
|
||||
app_state,
|
||||
},
|
||||
mitm,
|
||||
});
|
||||
|
||||
let executor = upgraded
|
||||
.extensions()
|
||||
.get::<Executor>()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let http_service = HttpServer::auto(executor).service(
|
||||
(
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
)
|
||||
.into_layer(service_fn({
|
||||
let request_ctx = request_ctx.clone();
|
||||
move |req| {
|
||||
let request_ctx = request_ctx.clone();
|
||||
async move { handle_mitm_request(req, request_ctx).await }
|
||||
}
|
||||
})),
|
||||
);
|
||||
|
||||
let https_service = TlsAcceptorLayer::new(acceptor_data)
|
||||
.with_store_client_hello(true)
|
||||
.into_layer(http_service);
|
||||
|
||||
https_service
|
||||
.serve(upgraded)
|
||||
.await
|
||||
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_mitm_request(
|
||||
req: Request,
|
||||
request_ctx: Arc<MitmRequestContext>,
|
||||
) -> Result<Response, std::convert::Infallible> {
|
||||
let response = match forward_request(req, &request_ctx).await {
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
warn!("MITM request handling failed: {err}");
|
||||
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
|
||||
}
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn forward_request(req: Request, request_ctx: &MitmRequestContext) -> Result<Response> {
|
||||
if let Some(response) = mitm_blocking_response(&req, &request_ctx.policy).await? {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
let target_host = request_ctx.policy.target_host.clone();
|
||||
let target_port = request_ctx.policy.target_port;
|
||||
let mitm = request_ctx.mitm.clone();
|
||||
|
||||
let method = req.method().as_str().to_string();
|
||||
let path = path_and_query(req.uri());
|
||||
let log_path = path_for_log(req.uri());
|
||||
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let authority = authority_header_value(&target_host, target_port);
|
||||
parts.uri = build_https_uri(&authority, &path)?;
|
||||
parts
|
||||
.headers
|
||||
.insert(HOST, HeaderValue::from_str(&authority)?);
|
||||
|
||||
let inspect = mitm.inspect_enabled();
|
||||
let max_body_bytes = mitm.max_body_bytes();
|
||||
let body = if inspect {
|
||||
inspect_body(
|
||||
body,
|
||||
max_body_bytes,
|
||||
RequestLogContext {
|
||||
host: authority.clone(),
|
||||
method: method.clone(),
|
||||
path: log_path.clone(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
body
|
||||
};
|
||||
|
||||
let upstream_req = Request::from_parts(parts, body);
|
||||
let upstream_resp = mitm.upstream.serve(upstream_req).await?;
|
||||
respond_with_inspection(
|
||||
upstream_resp,
|
||||
inspect,
|
||||
max_body_bytes,
|
||||
&method,
|
||||
&log_path,
|
||||
&authority,
|
||||
)
|
||||
}
|
||||
|
||||
async fn mitm_blocking_response(
|
||||
req: &Request,
|
||||
policy: &MitmPolicyContext,
|
||||
) -> Result<Option<Response>> {
|
||||
if req.method().as_str() == "CONNECT" {
|
||||
return Ok(Some(text_response(
|
||||
StatusCode::METHOD_NOT_ALLOWED,
|
||||
"CONNECT not supported inside MITM",
|
||||
)));
|
||||
}
|
||||
|
||||
let method = req.method().as_str().to_string();
|
||||
let log_path = path_for_log(req.uri());
|
||||
let client = req
|
||||
.extensions()
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
|
||||
if let Some(request_host) = extract_request_host(req) {
|
||||
let normalized = normalize_host(&request_host);
|
||||
if !normalized.is_empty() && normalized != policy.target_host {
|
||||
warn!(
|
||||
"MITM host mismatch (target={}, request_host={normalized})",
|
||||
policy.target_host
|
||||
);
|
||||
return Ok(Some(text_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"host mismatch",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// CONNECT already handled allowlist/denylist + decider policy. Re-check local/private
|
||||
// resolution here to defend against DNS rebinding between CONNECT and inner HTTPS requests.
|
||||
if matches!(
|
||||
policy
|
||||
.app_state
|
||||
.host_blocked(&policy.target_host, policy.target_port)
|
||||
.await?,
|
||||
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
|
||||
) {
|
||||
let reason = HostBlockReason::NotAllowedLocal.as_str();
|
||||
let _ = policy
|
||||
.app_state
|
||||
.record_blocked(BlockedRequest::new(BlockedRequestArgs {
|
||||
host: policy.target_host.clone(),
|
||||
reason: reason.to_string(),
|
||||
client: client.clone(),
|
||||
method: Some(method.clone()),
|
||||
mode: Some(policy.mode),
|
||||
protocol: "https".to_string(),
|
||||
decision: None,
|
||||
source: None,
|
||||
port: Some(policy.target_port),
|
||||
}))
|
||||
.await;
|
||||
warn!(
|
||||
"MITM blocked local/private target after CONNECT (host={}, port={}, method={method}, path={log_path})",
|
||||
policy.target_host, policy.target_port
|
||||
);
|
||||
return Ok(Some(blocked_text_response(reason)));
|
||||
}
|
||||
|
||||
if !policy.mode.allows_method(&method) {
|
||||
let _ = policy
|
||||
.app_state
|
||||
.record_blocked(BlockedRequest::new(BlockedRequestArgs {
|
||||
host: policy.target_host.clone(),
|
||||
reason: REASON_METHOD_NOT_ALLOWED.to_string(),
|
||||
client: client.clone(),
|
||||
method: Some(method.clone()),
|
||||
mode: Some(policy.mode),
|
||||
protocol: "https".to_string(),
|
||||
decision: None,
|
||||
source: None,
|
||||
port: Some(policy.target_port),
|
||||
}))
|
||||
.await;
|
||||
warn!(
|
||||
"MITM blocked by method policy (host={}, method={method}, path={log_path}, mode={:?}, allowed_methods=GET, HEAD, OPTIONS)",
|
||||
policy.target_host, policy.mode
|
||||
);
|
||||
return Ok(Some(blocked_text_response(REASON_METHOD_NOT_ALLOWED)));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn respond_with_inspection(
|
||||
resp: Response,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
method: &str,
|
||||
log_path: &str,
|
||||
authority: &str,
|
||||
) -> Result<Response> {
|
||||
if !inspect {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
let (parts, body) = resp.into_parts();
|
||||
let body = inspect_body(
|
||||
body,
|
||||
max_body_bytes,
|
||||
ResponseLogContext {
|
||||
host: authority.to_string(),
|
||||
method: method.to_string(),
|
||||
path: log_path.to_string(),
|
||||
status: parts.status,
|
||||
},
|
||||
);
|
||||
Ok(Response::from_parts(parts, body))
|
||||
}
|
||||
|
||||
fn inspect_body<T: BodyLoggable + Send + 'static>(
|
||||
body: Body,
|
||||
max_body_bytes: usize,
|
||||
ctx: T,
|
||||
) -> Body {
|
||||
Body::from_stream(InspectStream {
|
||||
inner: Box::pin(body.into_data_stream()),
|
||||
ctx: Some(Box::new(ctx)),
|
||||
len: 0,
|
||||
max_body_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
struct InspectStream<T> {
|
||||
inner: Pin<Box<BodyDataStream>>,
|
||||
ctx: Option<Box<T>>,
|
||||
len: usize,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
|
||||
impl<T: BodyLoggable> Stream for InspectStream<T> {
|
||||
type Item = Result<Bytes, BoxError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
match this.inner.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(bytes))) => {
|
||||
this.len = this.len.saturating_add(bytes.len());
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(None) => {
|
||||
if let Some(ctx) = this.ctx.take() {
|
||||
ctx.log(this.len, this.len > this.max_body_bytes);
|
||||
}
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RequestLogContext {
|
||||
host: String,
|
||||
method: String,
|
||||
path: String,
|
||||
}
|
||||
|
||||
struct ResponseLogContext {
|
||||
host: String,
|
||||
method: String,
|
||||
path: String,
|
||||
status: StatusCode,
|
||||
}
|
||||
|
||||
trait BodyLoggable {
|
||||
fn log(self, len: usize, truncated: bool);
|
||||
}
|
||||
|
||||
impl BodyLoggable for RequestLogContext {
|
||||
fn log(self, len: usize, truncated: bool) {
|
||||
let host = self.host;
|
||||
let method = self.method;
|
||||
let path = self.path;
|
||||
info!(
|
||||
"MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl BodyLoggable for ResponseLogContext {
|
||||
fn log(self, len: usize, truncated: bool) {
|
||||
let host = self.host;
|
||||
let method = self.method;
|
||||
let path = self.path;
|
||||
let status = self.status;
|
||||
info!(
|
||||
"MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_request_host(req: &Request) -> Option<String> {
|
||||
req.headers()
|
||||
.get(HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
|
||||
}
|
||||
|
||||
fn authority_header_value(host: &str, port: u16) -> String {
|
||||
// Host header / URI authority formatting.
|
||||
if host.contains(':') {
|
||||
if port == 443 {
|
||||
format!("[{host}]")
|
||||
} else {
|
||||
format!("[{host}]:{port}")
|
||||
}
|
||||
} else if port == 443 {
|
||||
host.to_string()
|
||||
} else {
|
||||
format!("{host}:{port}")
|
||||
}
|
||||
}
|
||||
|
||||
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
|
||||
let target = format!("https://{authority}{path}");
|
||||
Ok(target.parse()?)
|
||||
}
|
||||
|
||||
fn path_and_query(uri: &Uri) -> String {
|
||||
uri.path_and_query()
|
||||
.map(rama_http::uri::PathAndQuery::as_str)
|
||||
.unwrap_or("/")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn path_for_log(uri: &Uri) -> String {
|
||||
uri.path().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "mitm_tests.rs"]
|
||||
mod tests;
|
||||
110
codex-rs/network-proxy/src/mitm_tests.rs
Normal file
110
codex-rs/network-proxy/src/mitm_tests.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use super::*;
|
||||
|
||||
use crate::config::NetworkProxySettings;
|
||||
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
|
||||
use crate::runtime::network_proxy_state_for_policy;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rama_http::Body;
|
||||
use rama_http::Method;
|
||||
use rama_http::Request;
|
||||
use rama_http::StatusCode;
|
||||
|
||||
fn policy_ctx(
|
||||
app_state: Arc<NetworkProxyState>,
|
||||
mode: NetworkMode,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
) -> MitmPolicyContext {
|
||||
MitmPolicyContext {
|
||||
target_host: target_host.to_string(),
|
||||
target_port,
|
||||
mode,
|
||||
app_state,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mitm_policy_blocks_disallowed_method_and_records_telemetry() {
|
||||
let app_state = Arc::new(network_proxy_state_for_policy(NetworkProxySettings {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
..NetworkProxySettings::default()
|
||||
}));
|
||||
let ctx = policy_ctx(app_state.clone(), NetworkMode::Limited, "example.com", 443);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri("/v1/responses?api_key=secret")
|
||||
.header(HOST, "example.com")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = mitm_blocking_response(&req, &ctx)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("POST should be blocked in limited mode");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
assert_eq!(
|
||||
response.headers().get("x-proxy-error").unwrap(),
|
||||
"blocked-by-method-policy"
|
||||
);
|
||||
|
||||
let blocked = app_state.drain_blocked().await.unwrap();
|
||||
assert_eq!(blocked.len(), 1);
|
||||
assert_eq!(blocked[0].reason, REASON_METHOD_NOT_ALLOWED);
|
||||
assert_eq!(blocked[0].method.as_deref(), Some("POST"));
|
||||
assert_eq!(blocked[0].host, "example.com");
|
||||
assert_eq!(blocked[0].port, Some(443));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mitm_policy_rejects_host_mismatch() {
|
||||
let app_state = Arc::new(network_proxy_state_for_policy(NetworkProxySettings {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
..NetworkProxySettings::default()
|
||||
}));
|
||||
let ctx = policy_ctx(app_state.clone(), NetworkMode::Full, "example.com", 443);
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/")
|
||||
.header(HOST, "evil.example")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = mitm_blocking_response(&req, &ctx)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("mismatched host should be rejected");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(app_state.blocked_snapshot().await.unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mitm_policy_rechecks_local_private_target_after_connect() {
|
||||
let app_state = Arc::new(network_proxy_state_for_policy(NetworkProxySettings {
|
||||
allowed_domains: vec!["*".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkProxySettings::default()
|
||||
}));
|
||||
let ctx = policy_ctx(app_state.clone(), NetworkMode::Full, "10.0.0.1", 443);
|
||||
let req = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/health?token=secret")
|
||||
.header(HOST, "10.0.0.1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = mitm_blocking_response(&req, &ctx)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("local/private target should be blocked on inner request");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
|
||||
let blocked = app_state.drain_blocked().await.unwrap();
|
||||
assert_eq!(blocked.len(), 1);
|
||||
assert_eq!(blocked[0].reason, REASON_NOT_ALLOWED_LOCAL);
|
||||
assert_eq!(blocked[0].host, "10.0.0.1");
|
||||
assert_eq!(blocked[0].port, Some(443));
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
pub(crate) const REASON_DENIED: &str = "denied";
|
||||
pub(crate) const REASON_METHOD_NOT_ALLOWED: &str = "method_not_allowed";
|
||||
pub(crate) const REASON_MITM_REQUIRED: &str = "mitm_required";
|
||||
pub(crate) const REASON_NOT_ALLOWED: &str = "not_allowed";
|
||||
pub(crate) const REASON_NOT_ALLOWED_LOCAL: &str = "not_allowed_local";
|
||||
pub(crate) const REASON_POLICY_DENIED: &str = "policy_denied";
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::network_policy::NetworkPolicyDecision;
|
||||
use crate::network_policy::NetworkProtocol;
|
||||
use crate::reasons::REASON_DENIED;
|
||||
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_MITM_REQUIRED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
|
||||
use rama_http::Body;
|
||||
@@ -51,6 +52,7 @@ pub fn blocked_header_value(reason: &str) -> &'static str {
|
||||
REASON_NOT_ALLOWED | REASON_NOT_ALLOWED_LOCAL => "blocked-by-allowlist",
|
||||
REASON_DENIED => "blocked-by-denylist",
|
||||
REASON_METHOD_NOT_ALLOWED => "blocked-by-method-policy",
|
||||
REASON_MITM_REQUIRED => "blocked-by-mitm-required",
|
||||
_ => "blocked-by-policy",
|
||||
}
|
||||
}
|
||||
@@ -67,10 +69,19 @@ pub fn blocked_message(reason: &str) -> &'static str {
|
||||
REASON_METHOD_NOT_ALLOWED => {
|
||||
"Codex blocked this request: method not allowed in limited mode."
|
||||
}
|
||||
REASON_MITM_REQUIRED => "Codex blocked this request: MITM required for limited HTTPS.",
|
||||
_ => "Codex blocked this request by network policy.",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn blocked_text_response(reason: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "text/plain")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(Body::from(blocked_message(reason)))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
pub fn blocked_message_with_policy(reason: &str, details: &PolicyDecisionDetails<'_>) -> String {
|
||||
let _ = (details.reason, details.host);
|
||||
blocked_message(reason).to_string()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::config::NetworkProxyConfig;
|
||||
use crate::config::ValidatedUnixSocketPath;
|
||||
use crate::mitm::MitmState;
|
||||
use crate::policy::Host;
|
||||
use crate::policy::is_loopback_host;
|
||||
use crate::policy::is_non_public_ip;
|
||||
@@ -141,6 +142,7 @@ pub struct ConfigState {
|
||||
pub config: NetworkProxyConfig,
|
||||
pub allow_set: GlobSet,
|
||||
pub deny_set: GlobSet,
|
||||
pub mitm: Option<Arc<MitmState>>,
|
||||
pub constraints: NetworkProxyConstraints,
|
||||
pub blocked: VecDeque<BlockedRequest>,
|
||||
pub blocked_total: u64,
|
||||
@@ -499,6 +501,12 @@ impl NetworkProxyState {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mitm_state(&self) -> Result<Option<Arc<MitmState>>> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.mitm.clone())
|
||||
}
|
||||
|
||||
pub async fn add_allowed_domain(&self, host: &str) -> Result<()> {
|
||||
self.update_domain_list(host, DomainListKind::Allow).await
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::config::NetworkProxyConfig;
|
||||
use crate::mitm::MitmState;
|
||||
use crate::policy::DomainPattern;
|
||||
use crate::policy::compile_globset;
|
||||
use crate::runtime::ConfigState;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use crate::runtime::BlockedRequest;
|
||||
pub use crate::runtime::BlockedRequestArgs;
|
||||
@@ -57,10 +59,18 @@ pub fn build_config_state(
|
||||
crate::config::validate_unix_socket_allowlist_paths(&config)?;
|
||||
let deny_set = compile_globset(&config.network.denied_domains)?;
|
||||
let allow_set = compile_globset(&config.network.allowed_domains)?;
|
||||
let mitm = if config.network.mitm {
|
||||
Some(Arc::new(MitmState::new(
|
||||
config.network.allow_upstream_proxy,
|
||||
)?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(ConfigState {
|
||||
config,
|
||||
allow_set,
|
||||
deny_set,
|
||||
mitm,
|
||||
constraints,
|
||||
blocked: std::collections::VecDeque::new(),
|
||||
blocked_total: 0,
|
||||
|
||||
Reference in New Issue
Block a user