feat(network-proxy): add MITM support and gate limited-mode CONNECT (#9859)

## Description
- Adds MITM support (CA load/issue, TLS termination, optional body
inspection).
- Adds `codex-network-proxy init` to create
`CODEX_HOME/network_proxy/mitm`.
- Enforces limited-mode HTTPS correctly: `CONNECT` requires MITM,
otherwise blocked with `mitm_required`.
- Keeps `origin/main` layering/reload semantics (managed layers included
in reload checks).
- Centralizes block reasons (`REASON_MITM_REQUIRED`) and removes
`println!`.
- Scope is MITM-only (no SOCKS changes).

gated by `mitm=false` (default)
This commit is contained in:
viyatb-oai
2026-02-24 10:15:15 -08:00
committed by GitHub
parent ca556fa313
commit 8d3d58f992
13 changed files with 1091 additions and 12 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -2039,6 +2039,7 @@ dependencies = [
"async-trait",
"clap",
"codex-utils-absolute-path",
"codex-utils-home-dir",
"codex-utils-rustls-provider",
"globset",
"pretty_assertions",

View File

@@ -16,6 +16,7 @@ anyhow = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-utils-absolute-path = { workspace = true }
codex-utils-home-dir = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
globset = { workspace = true }
serde = { workspace = true, features = ["derive"] }

View File

@@ -34,6 +34,9 @@ allow_upstream_proxy = true
dangerously_allow_non_loopback_proxy = false
dangerously_allow_non_loopback_admin = false
mode = "full" # default when unset; use "limited" for read-only mode
# When true, HTTPS CONNECT can be terminated so limited-mode method policy still applies.
mitm = false
# CA cert/key are managed internally under $CODEX_HOME/proxy/ (ca.pem + ca.key).
# Hosts must match the allowlist (unless denied).
# If `allowed_domains` is empty, the proxy blocks requests until an allowlist is configured.
@@ -85,8 +88,9 @@ When a request is blocked, the proxy responds with `403` and includes:
- `blocked-by-method-policy`
- `blocked-by-policy`
In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed. HTTPS `CONNECT` and SOCKS5 are
blocked because they would bypass method enforcement.
In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed. HTTPS `CONNECT` requests require
MITM to enforce limited-mode method policy; otherwise they are blocked. SOCKS5 remains blocked in
limited mode.
Websocket clients typically tunnel `wss://` through HTTPS `CONNECT`; those CONNECT targets still go
through the same host allowlist/denylist checks.

View File

@@ -0,0 +1,344 @@
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use codex_utils_home_dir::find_codex_home;
use rama_net::tls::ApplicationProtocol;
use rama_tls_rustls::dep::pki_types::CertificateDer;
use rama_tls_rustls::dep::pki_types::PrivateKeyDer;
use rama_tls_rustls::dep::pki_types::pem::PemObject;
use rama_tls_rustls::dep::rcgen::BasicConstraints;
use rama_tls_rustls::dep::rcgen::CertificateParams;
use rama_tls_rustls::dep::rcgen::DistinguishedName;
use rama_tls_rustls::dep::rcgen::DnType;
use rama_tls_rustls::dep::rcgen::ExtendedKeyUsagePurpose;
use rama_tls_rustls::dep::rcgen::IsCa;
use rama_tls_rustls::dep::rcgen::Issuer;
use rama_tls_rustls::dep::rcgen::KeyPair;
use rama_tls_rustls::dep::rcgen::KeyUsagePurpose;
use rama_tls_rustls::dep::rcgen::PKCS_ECDSA_P256_SHA256;
use rama_tls_rustls::dep::rcgen::SanType;
use rama_tls_rustls::dep::rustls;
use rama_tls_rustls::server::TlsAcceptorData;
use std::fs;
use std::fs::File;
use std::fs::OpenOptions;
use std::io::Write;
use std::net::IpAddr;
use std::path::Path;
use std::path::PathBuf;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use tracing::info;
pub(super) struct ManagedMitmCa {
issuer: Issuer<'static, KeyPair>,
}
impl ManagedMitmCa {
pub(super) fn load_or_create() -> Result<Self> {
let (ca_cert_pem, ca_key_pem) = load_or_create_ca()?;
let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
let issuer: Issuer<'static, KeyPair> =
Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key).context("failed to parse CA cert")?;
Ok(Self { issuer })
}
pub(super) fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
let cert = CertificateDer::from_pem_slice(cert_pem.as_bytes())
.context("failed to parse host cert PEM")?;
let key = PrivateKeyDer::from_pem_slice(key_pem.as_bytes())
.context("failed to parse host key PEM")?;
let mut server_config =
rustls::ServerConfig::builder_with_protocol_versions(rustls::ALL_VERSIONS)
.with_no_client_auth()
.with_single_cert(vec![cert], key)
.context("failed to build rustls server config")?;
server_config.alpn_protocols = vec![
ApplicationProtocol::HTTP_2.as_bytes().to_vec(),
ApplicationProtocol::HTTP_11.as_bytes().to_vec(),
];
Ok(TlsAcceptorData::from(server_config))
}
}
fn issue_host_certificate_pem(
host: &str,
issuer: &Issuer<'_, KeyPair>,
) -> Result<(String, String)> {
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
let mut params = CertificateParams::new(Vec::new())
.map_err(|err| anyhow!("failed to create cert params: {err}"))?;
params.subject_alt_names.push(SanType::IpAddress(ip));
params
} else {
CertificateParams::new(vec![host.to_string()])
.map_err(|err| anyhow!("failed to create cert params: {err}"))?
};
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
.map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
let cert = params
.signed_by(&key_pair, issuer)
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
Ok((cert.pem(), key_pair.serialize_pem()))
}
const MANAGED_MITM_CA_DIR: &str = "proxy";
const MANAGED_MITM_CA_CERT: &str = "ca.pem";
const MANAGED_MITM_CA_KEY: &str = "ca.key";
fn managed_ca_paths() -> Result<(PathBuf, PathBuf)> {
let codex_home =
find_codex_home().context("failed to resolve CODEX_HOME for managed MITM CA")?;
let proxy_dir = codex_home.join(MANAGED_MITM_CA_DIR);
Ok((
proxy_dir.join(MANAGED_MITM_CA_CERT),
proxy_dir.join(MANAGED_MITM_CA_KEY),
))
}
fn load_or_create_ca() -> Result<(String, String)> {
let (cert_path, key_path) = managed_ca_paths()?;
if cert_path.exists() || key_path.exists() {
if !cert_path.exists() || !key_path.exists() {
return Err(anyhow!(
"both managed MITM CA files must exist (cert={}, key={})",
cert_path.display(),
key_path.display()
));
}
validate_existing_ca_key_file(&key_path)?;
let cert_pem = fs::read_to_string(&cert_path)
.with_context(|| format!("failed to read CA cert {}", cert_path.display()))?;
let key_pem = fs::read_to_string(&key_path)
.with_context(|| format!("failed to read CA key {}", key_path.display()))?;
return Ok((cert_pem, key_pem));
}
if let Some(parent) = cert_path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create {}", parent.display()))?;
}
if let Some(parent) = key_path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create {}", parent.display()))?;
}
let (cert_pem, key_pem) = generate_ca()?;
// The CA key is a high-value secret. Create it atomically with restrictive permissions.
// The cert can be world-readable, but we still write it atomically to avoid partial writes.
//
// We intentionally use create-new semantics: if a key already exists, we should not overwrite
// it silently (that would invalidate previously-trusted cert chains).
write_atomic_create_new(&key_path, key_pem.as_bytes(), 0o600)
.with_context(|| format!("failed to persist CA key {}", key_path.display()))?;
if let Err(err) = write_atomic_create_new(&cert_path, cert_pem.as_bytes(), 0o644)
.with_context(|| format!("failed to persist CA cert {}", cert_path.display()))
{
// Avoid leaving a partially-created CA around (cert missing) if the second write fails.
let _ = fs::remove_file(&key_path);
return Err(err);
}
let cert_path = cert_path.display();
let key_path = key_path.display();
info!("generated MITM CA (cert_path={cert_path}, key_path={key_path})");
Ok((cert_pem, key_pem))
}
fn generate_ca() -> Result<(String, String)> {
let mut params = CertificateParams::default();
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![
KeyUsagePurpose::KeyCertSign,
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, "network_proxy MITM CA");
params.distinguished_name = dn;
let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
.map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
let cert = params
.self_signed(&key_pair)
.map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
Ok((cert.pem(), key_pair.serialize_pem()))
}
fn write_atomic_create_new(path: &Path, contents: &[u8], mode: u32) -> Result<()> {
let parent = path
.parent()
.ok_or_else(|| anyhow!("missing parent directory"))?;
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
let pid = std::process::id();
let file_name = path.file_name().unwrap_or_default().to_string_lossy();
let tmp_path = parent.join(format!(".{file_name}.tmp.{pid}.{nanos}"));
let mut file = open_create_new_with_mode(&tmp_path, mode)?;
file.write_all(contents)
.with_context(|| format!("failed to write {}", tmp_path.display()))?;
file.sync_all()
.with_context(|| format!("failed to fsync {}", tmp_path.display()))?;
drop(file);
// Create the final file using "create-new" semantics (no overwrite). `rename` on Unix can
// overwrite existing files, so prefer a hard-link, which fails if the destination exists.
match fs::hard_link(&tmp_path, path) {
Ok(()) => {
fs::remove_file(&tmp_path)
.with_context(|| format!("failed to remove {}", tmp_path.display()))?;
}
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
let _ = fs::remove_file(&tmp_path);
return Err(anyhow!(
"refusing to overwrite existing file {}",
path.display()
));
}
Err(_) => {
// Best-effort fallback for environments where hard links are not supported.
// This is still subject to a TOCTOU race, but the typical case is a private per-user
// config directory, where other users cannot create files anyway.
if path.exists() {
let _ = fs::remove_file(&tmp_path);
return Err(anyhow!(
"refusing to overwrite existing file {}",
path.display()
));
}
fs::rename(&tmp_path, path).with_context(|| {
format!(
"failed to rename {} -> {}",
tmp_path.display(),
path.display()
)
})?;
}
}
// Best-effort durability: ensure the directory entry is persisted too.
let dir = File::open(parent).with_context(|| format!("failed to open {}", parent.display()))?;
dir.sync_all()
.with_context(|| format!("failed to fsync {}", parent.display()))?;
Ok(())
}
#[cfg(unix)]
fn validate_existing_ca_key_file(path: &Path) -> Result<()> {
use std::os::unix::fs::PermissionsExt;
let metadata = fs::symlink_metadata(path)
.with_context(|| format!("failed to stat CA key {}", path.display()))?;
if metadata.file_type().is_symlink() {
return Err(anyhow!(
"refusing to use symlink for managed MITM CA key {}",
path.display()
));
}
if !metadata.is_file() {
return Err(anyhow!(
"managed MITM CA key is not a regular file: {}",
path.display()
));
}
let mode = metadata.permissions().mode() & 0o777;
if mode & 0o077 != 0 {
return Err(anyhow!(
"managed MITM CA key {} must not be group/world accessible (mode={mode:o}; expected <= 600)",
path.display()
));
}
Ok(())
}
#[cfg(not(unix))]
fn validate_existing_ca_key_file(_path: &Path) -> Result<()> {
Ok(())
}
#[cfg(unix)]
fn open_create_new_with_mode(path: &Path, mode: u32) -> Result<File> {
use std::os::unix::fs::OpenOptionsExt;
OpenOptions::new()
.write(true)
.create_new(true)
.mode(mode)
.open(path)
.with_context(|| format!("failed to create {}", path.display()))
}
#[cfg(not(unix))]
fn open_create_new_with_mode(path: &Path, _mode: u32) -> Result<File> {
OpenOptions::new()
.write(true)
.create_new(true)
.open(path)
.with_context(|| format!("failed to create {}", path.display()))
}
#[cfg(all(test, unix))]
mod tests {
use super::*;
use std::os::unix::fs::PermissionsExt;
use tempfile::tempdir;
#[test]
fn validate_existing_ca_key_file_rejects_group_world_permissions() {
let dir = tempdir().unwrap();
let key_path = dir.path().join("ca.key");
fs::write(&key_path, "key").unwrap();
fs::set_permissions(&key_path, fs::Permissions::from_mode(0o644)).unwrap();
let err = validate_existing_ca_key_file(&key_path).unwrap_err();
assert!(
err.to_string().contains("group/world accessible"),
"unexpected error: {err:#}"
);
}
#[test]
fn validate_existing_ca_key_file_rejects_symlink() {
use std::os::unix::fs::symlink;
let dir = tempdir().unwrap();
let target = dir.path().join("real.key");
let link = dir.path().join("ca.key");
fs::write(&target, "key").unwrap();
symlink(&target, &link).unwrap();
let err = validate_existing_ca_key_file(&link).unwrap_err();
assert!(
err.to_string().contains("symlink"),
"unexpected error: {err:#}"
);
}
#[test]
fn validate_existing_ca_key_file_allows_private_permissions() {
let dir = tempdir().unwrap();
let key_path = dir.path().join("ca.key");
fs::write(&key_path, "key").unwrap();
fs::set_permissions(&key_path, fs::Permissions::from_mode(0o600)).unwrap();
validate_existing_ca_key_file(&key_path).unwrap();
}
}

View File

@@ -45,6 +45,8 @@ pub struct NetworkProxySettings {
#[serde(default)]
pub allow_unix_sockets: Vec<String>,
pub allow_local_binding: bool,
#[serde(default)]
pub mitm: bool,
}
impl Default for NetworkProxySettings {
@@ -65,6 +67,7 @@ impl Default for NetworkProxySettings {
denied_domains: Vec::new(),
allow_unix_sockets: Vec::new(),
allow_local_binding: true,
mitm: false,
}
}
}
@@ -74,6 +77,7 @@ impl Default for NetworkProxySettings {
pub enum NetworkMode {
/// Limited (read-only) access: only GET/HEAD/OPTIONS are allowed for HTTP. HTTPS CONNECT is
/// blocked unless MITM is enabled so the proxy can enforce method policy on inner requests.
/// SOCKS5 remains blocked in limited mode.
Limited,
/// Full network access: all HTTP methods are allowed, and HTTPS CONNECTs are tunneled without
/// MITM interception.
@@ -393,6 +397,7 @@ mod tests {
denied_domains: Vec::new(),
allow_unix_sockets: Vec::new(),
allow_local_binding: true,
mitm: false,
}
);
}

View File

@@ -1,4 +1,5 @@
use crate::config::NetworkMode;
use crate::mitm;
use crate::network_policy::NetworkDecision;
use crate::network_policy::NetworkDecisionSource;
use crate::network_policy::NetworkPolicyDecider;
@@ -9,6 +10,7 @@ use crate::network_policy::NetworkProtocol;
use crate::network_policy::evaluate_host_policy;
use crate::policy::normalize_host;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::reasons::REASON_MITM_REQUIRED;
use crate::reasons::REASON_NOT_ALLOWED;
use crate::reasons::REASON_PROXY_DISABLED;
use crate::responses::PolicyDecisionDetails;
@@ -49,6 +51,7 @@ use rama_http_backend::server::HttpServer;
use rama_http_backend::server::layer::upgrade::UpgradeLayer;
use rama_http_backend::server::layer::upgrade::Upgraded;
use rama_net::Protocol;
use rama_net::address::HostWithOptPort;
use rama_net::address::ProxyAddress;
use rama_net::client::ConnectorService;
use rama_net::client::EstablishedClientConnection;
@@ -233,10 +236,20 @@ async fn http_connect_accept(
.await
.map_err(|err| internal_error("failed to read network mode", err))?;
if mode == NetworkMode::Limited {
let mitm_state = match app_state.mitm_state().await {
Ok(state) => state,
Err(err) => {
error!("failed to load MITM state: {err}");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
if mode == NetworkMode::Limited && mitm_state.is_none() {
// Limited mode is designed to be read-only. Without MITM, a CONNECT tunnel would hide the
// inner HTTP method/headers from the proxy, effectively bypassing method policy.
let details = PolicyDecisionDetails {
decision: NetworkPolicyDecision::Deny,
reason: REASON_METHOD_NOT_ALLOWED,
reason: REASON_MITM_REQUIRED,
source: NetworkDecisionSource::ModeGuard,
protocol: NetworkProtocol::HttpsConnect,
host: &host,
@@ -245,7 +258,7 @@ async fn http_connect_accept(
let _ = app_state
.record_blocked(BlockedRequest::new(BlockedRequestArgs {
host: host.clone(),
reason: REASON_METHOD_NOT_ALLOWED.to_string(),
reason: REASON_MITM_REQUIRED.to_string(),
client: client.clone(),
method: Some("CONNECT".to_string()),
mode: Some(NetworkMode::Limited),
@@ -256,15 +269,17 @@ async fn http_connect_accept(
}))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("CONNECT blocked by method policy (client={client}, host={host}, mode=limited)");
return Err(blocked_text_with_details(
REASON_METHOD_NOT_ALLOWED,
&details,
));
warn!(
"CONNECT blocked; MITM required for read-only HTTPS in limited mode (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Err(blocked_text_with_details(REASON_MITM_REQUIRED, &details));
}
req.extensions_mut().insert(ProxyTarget(authority));
req.extensions_mut().insert(mode);
if let Some(mitm_state) = mitm_state {
req.extensions_mut().insert(mitm_state);
}
Ok((
Response::builder()
@@ -276,9 +291,34 @@ async fn http_connect_accept(
}
async fn http_connect_proxy(upgraded: Upgraded) -> Result<(), Infallible> {
if upgraded.extensions().get::<ProxyTarget>().is_none() {
let mode = upgraded
.extensions()
.get::<NetworkMode>()
.copied()
.unwrap_or(NetworkMode::Full);
let Some(target) = upgraded
.extensions()
.get::<ProxyTarget>()
.map(|t| t.0.clone())
else {
warn!("CONNECT missing proxy target");
return Ok(());
};
if mode == NetworkMode::Limited
&& upgraded
.extensions()
.get::<Arc<mitm::MitmState>>()
.is_some()
{
let host = normalize_host(&target.host.to_string());
let port = target.port;
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
if let Err(err) = mitm::mitm_tunnel(upgraded).await {
warn!("MITM tunnel error: {err}");
}
return Ok(());
}
let allow_upstream_proxy = match upgraded
@@ -465,6 +505,10 @@ async fn http_plain_proxy(
};
let host = normalize_host(&authority.host.to_string());
let port = authority.port;
if let Err(err) = validate_plain_http_host_header(&req, &authority) {
warn!("HTTP request host mismatch: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
}
let enabled = match app_state
.enabled()
.await
@@ -669,6 +713,45 @@ fn remove_hop_by_hop_request_headers(headers: &mut HeaderMap) {
}
}
fn validate_plain_http_host_header(
req: &Request,
target: &rama_net::address::HostWithPort,
) -> std::result::Result<(), &'static str> {
// Only enforce this in absolute-form requests. Origin-form requests use the Host header as the
// routing authority, so there is no separate target authority to compare against.
if req.uri().authority().is_none() {
return Ok(());
}
let Some(raw_host) = req.headers().get(header::HOST) else {
return Ok(());
};
let raw_host = raw_host.to_str().map_err(|_| "invalid Host header")?;
let parsed = HostWithOptPort::try_from(raw_host).map_err(|_| "invalid Host header")?;
let target_host = normalize_host(&target.host.to_string());
let request_host = normalize_host(&parsed.host.to_string());
if request_host.is_empty() || request_host != target_host {
return Err("request Host header host does not match target authority");
}
let expected_port = target.port;
let request_port = match parsed.port {
Some(port) => port,
None => match req.uri().scheme_str() {
Some("http") => 80,
Some("https") => 443,
Some(_) | None => expected_port,
},
};
if request_port != expected_port {
return Err("request Host header port does not match target authority");
}
Ok(())
}
fn json_blocked(host: &str, reason: &str, details: Option<&PolicyDecisionDetails<'_>>) -> Response {
let (message, decision, source, protocol, port) = details
.map(|details| {
@@ -804,7 +887,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::FORBIDDEN);
assert_eq!(
response.headers().get("x-proxy-error").unwrap(),
"blocked-by-method-policy"
"blocked-by-mitm-required"
);
}
@@ -853,6 +936,23 @@ mod tests {
);
}
#[tokio::test]
async fn http_plain_proxy_rejects_absolute_uri_host_header_mismatch() {
let state = Arc::new(network_proxy_state_for_policy(
NetworkProxySettings::default(),
));
let mut req = Request::builder()
.method(Method::GET)
.uri("http://raw.githubusercontent.com/openai/codex/main/README.md")
.header(header::HOST, "api.github.com")
.body(Body::empty())
.unwrap();
req.extensions_mut().insert(state);
let response = http_plain_proxy(None, req).await;
assert_eq!(response.unwrap().status(), StatusCode::BAD_REQUEST);
}
#[test]
fn remove_hop_by_hop_request_headers_keeps_forwarding_headers() {
let mut headers = HeaderMap::new();

View File

@@ -1,8 +1,10 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]
mod admin;
mod certs;
mod config;
mod http_proxy;
mod mitm;
mod network_policy;
mod policy;
mod proxy;

View File

@@ -0,0 +1,482 @@
use crate::certs::ManagedMitmCa;
use crate::config::NetworkMode;
use crate::policy::normalize_host;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::responses::blocked_text_response;
use crate::responses::text_response;
use crate::runtime::HostBlockDecision;
use crate::runtime::HostBlockReason;
use crate::state::BlockedRequest;
use crate::state::BlockedRequestArgs;
use crate::state::NetworkProxyState;
use crate::upstream::UpstreamClient;
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use rama_core::Layer;
use rama_core::Service;
use rama_core::bytes::Bytes;
use rama_core::error::BoxError;
use rama_core::extensions::ExtensionsRef;
use rama_core::futures::stream::Stream;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::BodyDataStream;
use rama_http::HeaderValue;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http::Uri;
use rama_http::header::HOST;
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
use rama_http_backend::server::HttpServer;
use rama_http_backend::server::layer::upgrade::Upgraded;
use rama_net::proxy::ProxyTarget;
use rama_net::stream::SocketInfo;
use rama_tls_rustls::server::TlsAcceptorData;
use rama_tls_rustls::server::TlsAcceptorLayer;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context as TaskContext;
use std::task::Poll;
use tracing::info;
use tracing::warn;
/// State needed to terminate a CONNECT tunnel and enforce policy on inner HTTPS requests.
pub struct MitmState {
ca: ManagedMitmCa,
upstream: UpstreamClient,
inspect: bool,
max_body_bytes: usize,
}
#[derive(Clone)]
struct MitmPolicyContext {
target_host: String,
target_port: u16,
mode: NetworkMode,
app_state: Arc<NetworkProxyState>,
}
#[derive(Clone)]
struct MitmRequestContext {
policy: MitmPolicyContext,
mitm: Arc<MitmState>,
}
const MITM_INSPECT_BODIES: bool = false;
const MITM_MAX_BODY_BYTES: usize = 4096;
impl std::fmt::Debug for MitmState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Avoid dumping internal state (CA material, connectors, etc.) to logs.
f.debug_struct("MitmState")
.field("inspect", &self.inspect)
.field("max_body_bytes", &self.max_body_bytes)
.finish_non_exhaustive()
}
}
impl MitmState {
pub(crate) fn new(allow_upstream_proxy: bool) -> Result<Self> {
// MITM exists to make limited-mode HTTPS enforceable: once CONNECT is established, plain
// proxying would lose visibility into the inner HTTP request. We generate/load a local CA
// and issue per-host leaf certs so we can terminate TLS and apply policy.
let ca = ManagedMitmCa::load_or_create()?;
let upstream = if allow_upstream_proxy {
UpstreamClient::from_env_proxy()
} else {
UpstreamClient::direct()
};
Ok(Self {
ca,
upstream,
inspect: MITM_INSPECT_BODIES,
max_body_bytes: MITM_MAX_BODY_BYTES,
})
}
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
self.ca.tls_acceptor_data_for_host(host)
}
pub(crate) fn inspect_enabled(&self) -> bool {
self.inspect
}
pub(crate) fn max_body_bytes(&self) -> usize {
self.max_body_bytes
}
}
/// Terminate the upgraded CONNECT stream with a generated leaf cert and proxy inner HTTPS traffic.
pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
let mitm = upgraded
.extensions()
.get::<Arc<MitmState>>()
.cloned()
.context("missing MITM state")?;
let app_state = upgraded
.extensions()
.get::<Arc<NetworkProxyState>>()
.cloned()
.context("missing app state")?;
let target = upgraded
.extensions()
.get::<ProxyTarget>()
.context("missing proxy target")?
.0
.clone();
let target_host = normalize_host(&target.host.to_string());
let target_port = target.port;
let acceptor_data = mitm.tls_acceptor_data_for_host(&target_host)?;
let mode = upgraded
.extensions()
.get::<NetworkMode>()
.copied()
.unwrap_or(NetworkMode::Full);
let request_ctx = Arc::new(MitmRequestContext {
policy: MitmPolicyContext {
target_host,
target_port,
mode,
app_state,
},
mitm,
});
let executor = upgraded
.extensions()
.get::<Executor>()
.cloned()
.unwrap_or_default();
let http_service = HttpServer::auto(executor).service(
(
RemoveResponseHeaderLayer::hop_by_hop(),
RemoveRequestHeaderLayer::hop_by_hop(),
)
.into_layer(service_fn({
let request_ctx = request_ctx.clone();
move |req| {
let request_ctx = request_ctx.clone();
async move { handle_mitm_request(req, request_ctx).await }
}
})),
);
let https_service = TlsAcceptorLayer::new(acceptor_data)
.with_store_client_hello(true)
.into_layer(http_service);
https_service
.serve(upgraded)
.await
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
Ok(())
}
async fn handle_mitm_request(
req: Request,
request_ctx: Arc<MitmRequestContext>,
) -> Result<Response, std::convert::Infallible> {
let response = match forward_request(req, &request_ctx).await {
Ok(resp) => resp,
Err(err) => {
warn!("MITM request handling failed: {err}");
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
}
};
Ok(response)
}
async fn forward_request(req: Request, request_ctx: &MitmRequestContext) -> Result<Response> {
if let Some(response) = mitm_blocking_response(&req, &request_ctx.policy).await? {
return Ok(response);
}
let target_host = request_ctx.policy.target_host.clone();
let target_port = request_ctx.policy.target_port;
let mitm = request_ctx.mitm.clone();
let method = req.method().as_str().to_string();
let path = path_and_query(req.uri());
let log_path = path_for_log(req.uri());
let (mut parts, body) = req.into_parts();
let authority = authority_header_value(&target_host, target_port);
parts.uri = build_https_uri(&authority, &path)?;
parts
.headers
.insert(HOST, HeaderValue::from_str(&authority)?);
let inspect = mitm.inspect_enabled();
let max_body_bytes = mitm.max_body_bytes();
let body = if inspect {
inspect_body(
body,
max_body_bytes,
RequestLogContext {
host: authority.clone(),
method: method.clone(),
path: log_path.clone(),
},
)
} else {
body
};
let upstream_req = Request::from_parts(parts, body);
let upstream_resp = mitm.upstream.serve(upstream_req).await?;
respond_with_inspection(
upstream_resp,
inspect,
max_body_bytes,
&method,
&log_path,
&authority,
)
}
async fn mitm_blocking_response(
req: &Request,
policy: &MitmPolicyContext,
) -> Result<Option<Response>> {
if req.method().as_str() == "CONNECT" {
return Ok(Some(text_response(
StatusCode::METHOD_NOT_ALLOWED,
"CONNECT not supported inside MITM",
)));
}
let method = req.method().as_str().to_string();
let log_path = path_for_log(req.uri());
let client = req
.extensions()
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
if let Some(request_host) = extract_request_host(req) {
let normalized = normalize_host(&request_host);
if !normalized.is_empty() && normalized != policy.target_host {
warn!(
"MITM host mismatch (target={}, request_host={normalized})",
policy.target_host
);
return Ok(Some(text_response(
StatusCode::BAD_REQUEST,
"host mismatch",
)));
}
}
// CONNECT already handled allowlist/denylist + decider policy. Re-check local/private
// resolution here to defend against DNS rebinding between CONNECT and inner HTTPS requests.
if matches!(
policy
.app_state
.host_blocked(&policy.target_host, policy.target_port)
.await?,
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
) {
let reason = HostBlockReason::NotAllowedLocal.as_str();
let _ = policy
.app_state
.record_blocked(BlockedRequest::new(BlockedRequestArgs {
host: policy.target_host.clone(),
reason: reason.to_string(),
client: client.clone(),
method: Some(method.clone()),
mode: Some(policy.mode),
protocol: "https".to_string(),
decision: None,
source: None,
port: Some(policy.target_port),
}))
.await;
warn!(
"MITM blocked local/private target after CONNECT (host={}, port={}, method={method}, path={log_path})",
policy.target_host, policy.target_port
);
return Ok(Some(blocked_text_response(reason)));
}
if !policy.mode.allows_method(&method) {
let _ = policy
.app_state
.record_blocked(BlockedRequest::new(BlockedRequestArgs {
host: policy.target_host.clone(),
reason: REASON_METHOD_NOT_ALLOWED.to_string(),
client: client.clone(),
method: Some(method.clone()),
mode: Some(policy.mode),
protocol: "https".to_string(),
decision: None,
source: None,
port: Some(policy.target_port),
}))
.await;
warn!(
"MITM blocked by method policy (host={}, method={method}, path={log_path}, mode={:?}, allowed_methods=GET, HEAD, OPTIONS)",
policy.target_host, policy.mode
);
return Ok(Some(blocked_text_response(REASON_METHOD_NOT_ALLOWED)));
}
Ok(None)
}
fn respond_with_inspection(
resp: Response,
inspect: bool,
max_body_bytes: usize,
method: &str,
log_path: &str,
authority: &str,
) -> Result<Response> {
if !inspect {
return Ok(resp);
}
let (parts, body) = resp.into_parts();
let body = inspect_body(
body,
max_body_bytes,
ResponseLogContext {
host: authority.to_string(),
method: method.to_string(),
path: log_path.to_string(),
status: parts.status,
},
);
Ok(Response::from_parts(parts, body))
}
fn inspect_body<T: BodyLoggable + Send + 'static>(
body: Body,
max_body_bytes: usize,
ctx: T,
) -> Body {
Body::from_stream(InspectStream {
inner: Box::pin(body.into_data_stream()),
ctx: Some(Box::new(ctx)),
len: 0,
max_body_bytes,
})
}
struct InspectStream<T> {
inner: Pin<Box<BodyDataStream>>,
ctx: Option<Box<T>>,
len: usize,
max_body_bytes: usize,
}
impl<T: BodyLoggable> Stream for InspectStream<T> {
type Item = Result<Bytes, BoxError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match this.inner.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.len = this.len.saturating_add(bytes.len());
Poll::Ready(Some(Ok(bytes)))
}
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(None) => {
if let Some(ctx) = this.ctx.take() {
ctx.log(this.len, this.len > this.max_body_bytes);
}
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
struct RequestLogContext {
host: String,
method: String,
path: String,
}
struct ResponseLogContext {
host: String,
method: String,
path: String,
status: StatusCode,
}
trait BodyLoggable {
fn log(self, len: usize, truncated: bool);
}
impl BodyLoggable for RequestLogContext {
fn log(self, len: usize, truncated: bool) {
let host = self.host;
let method = self.method;
let path = self.path;
info!(
"MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
);
}
}
impl BodyLoggable for ResponseLogContext {
fn log(self, len: usize, truncated: bool) {
let host = self.host;
let method = self.method;
let path = self.path;
let status = self.status;
info!(
"MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
);
}
}
fn extract_request_host(req: &Request) -> Option<String> {
req.headers()
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(ToString::to_string)
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
}
fn authority_header_value(host: &str, port: u16) -> String {
// Host header / URI authority formatting.
if host.contains(':') {
if port == 443 {
format!("[{host}]")
} else {
format!("[{host}]:{port}")
}
} else if port == 443 {
host.to_string()
} else {
format!("{host}:{port}")
}
}
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
let target = format!("https://{authority}{path}");
Ok(target.parse()?)
}
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(rama_http::uri::PathAndQuery::as_str)
.unwrap_or("/")
.to_string()
}
fn path_for_log(uri: &Uri) -> String {
uri.path().to_string()
}
#[cfg(test)]
#[path = "mitm_tests.rs"]
mod tests;

View File

@@ -0,0 +1,110 @@
use super::*;
use crate::config::NetworkProxySettings;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
use crate::runtime::network_proxy_state_for_policy;
use pretty_assertions::assert_eq;
use rama_http::Body;
use rama_http::Method;
use rama_http::Request;
use rama_http::StatusCode;
fn policy_ctx(
app_state: Arc<NetworkProxyState>,
mode: NetworkMode,
target_host: &str,
target_port: u16,
) -> MitmPolicyContext {
MitmPolicyContext {
target_host: target_host.to_string(),
target_port,
mode,
app_state,
}
}
#[tokio::test]
async fn mitm_policy_blocks_disallowed_method_and_records_telemetry() {
let app_state = Arc::new(network_proxy_state_for_policy(NetworkProxySettings {
allowed_domains: vec!["example.com".to_string()],
..NetworkProxySettings::default()
}));
let ctx = policy_ctx(app_state.clone(), NetworkMode::Limited, "example.com", 443);
let req = Request::builder()
.method(Method::POST)
.uri("/v1/responses?api_key=secret")
.header(HOST, "example.com")
.body(Body::empty())
.unwrap();
let response = mitm_blocking_response(&req, &ctx)
.await
.unwrap()
.expect("POST should be blocked in limited mode");
assert_eq!(response.status(), StatusCode::FORBIDDEN);
assert_eq!(
response.headers().get("x-proxy-error").unwrap(),
"blocked-by-method-policy"
);
let blocked = app_state.drain_blocked().await.unwrap();
assert_eq!(blocked.len(), 1);
assert_eq!(blocked[0].reason, REASON_METHOD_NOT_ALLOWED);
assert_eq!(blocked[0].method.as_deref(), Some("POST"));
assert_eq!(blocked[0].host, "example.com");
assert_eq!(blocked[0].port, Some(443));
}
#[tokio::test]
async fn mitm_policy_rejects_host_mismatch() {
let app_state = Arc::new(network_proxy_state_for_policy(NetworkProxySettings {
allowed_domains: vec!["example.com".to_string()],
..NetworkProxySettings::default()
}));
let ctx = policy_ctx(app_state.clone(), NetworkMode::Full, "example.com", 443);
let req = Request::builder()
.method(Method::GET)
.uri("/")
.header(HOST, "evil.example")
.body(Body::empty())
.unwrap();
let response = mitm_blocking_response(&req, &ctx)
.await
.unwrap()
.expect("mismatched host should be rejected");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
assert_eq!(app_state.blocked_snapshot().await.unwrap().len(), 0);
}
#[tokio::test]
async fn mitm_policy_rechecks_local_private_target_after_connect() {
let app_state = Arc::new(network_proxy_state_for_policy(NetworkProxySettings {
allowed_domains: vec!["*".to_string()],
allow_local_binding: false,
..NetworkProxySettings::default()
}));
let ctx = policy_ctx(app_state.clone(), NetworkMode::Full, "10.0.0.1", 443);
let req = Request::builder()
.method(Method::GET)
.uri("/health?token=secret")
.header(HOST, "10.0.0.1")
.body(Body::empty())
.unwrap();
let response = mitm_blocking_response(&req, &ctx)
.await
.unwrap()
.expect("local/private target should be blocked on inner request");
assert_eq!(response.status(), StatusCode::FORBIDDEN);
let blocked = app_state.drain_blocked().await.unwrap();
assert_eq!(blocked.len(), 1);
assert_eq!(blocked[0].reason, REASON_NOT_ALLOWED_LOCAL);
assert_eq!(blocked[0].host, "10.0.0.1");
assert_eq!(blocked[0].port, Some(443));
}

View File

@@ -1,5 +1,6 @@
pub(crate) const REASON_DENIED: &str = "denied";
pub(crate) const REASON_METHOD_NOT_ALLOWED: &str = "method_not_allowed";
pub(crate) const REASON_MITM_REQUIRED: &str = "mitm_required";
pub(crate) const REASON_NOT_ALLOWED: &str = "not_allowed";
pub(crate) const REASON_NOT_ALLOWED_LOCAL: &str = "not_allowed_local";
pub(crate) const REASON_POLICY_DENIED: &str = "policy_denied";

View File

@@ -3,6 +3,7 @@ use crate::network_policy::NetworkPolicyDecision;
use crate::network_policy::NetworkProtocol;
use crate::reasons::REASON_DENIED;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::reasons::REASON_MITM_REQUIRED;
use crate::reasons::REASON_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
use rama_http::Body;
@@ -51,6 +52,7 @@ pub fn blocked_header_value(reason: &str) -> &'static str {
REASON_NOT_ALLOWED | REASON_NOT_ALLOWED_LOCAL => "blocked-by-allowlist",
REASON_DENIED => "blocked-by-denylist",
REASON_METHOD_NOT_ALLOWED => "blocked-by-method-policy",
REASON_MITM_REQUIRED => "blocked-by-mitm-required",
_ => "blocked-by-policy",
}
}
@@ -67,10 +69,19 @@ pub fn blocked_message(reason: &str) -> &'static str {
REASON_METHOD_NOT_ALLOWED => {
"Codex blocked this request: method not allowed in limited mode."
}
REASON_MITM_REQUIRED => "Codex blocked this request: MITM required for limited HTTPS.",
_ => "Codex blocked this request by network policy.",
}
}
pub fn blocked_text_response(reason: &str) -> Response {
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "text/plain")
.header("x-proxy-error", blocked_header_value(reason))
.body(Body::from(blocked_message(reason)))
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
pub fn blocked_message_with_policy(reason: &str, details: &PolicyDecisionDetails<'_>) -> String {
let _ = (details.reason, details.host);
blocked_message(reason).to_string()

View File

@@ -1,6 +1,7 @@
use crate::config::NetworkMode;
use crate::config::NetworkProxyConfig;
use crate::config::ValidatedUnixSocketPath;
use crate::mitm::MitmState;
use crate::policy::Host;
use crate::policy::is_loopback_host;
use crate::policy::is_non_public_ip;
@@ -141,6 +142,7 @@ pub struct ConfigState {
pub config: NetworkProxyConfig,
pub allow_set: GlobSet,
pub deny_set: GlobSet,
pub mitm: Option<Arc<MitmState>>,
pub constraints: NetworkProxyConstraints,
pub blocked: VecDeque<BlockedRequest>,
pub blocked_total: u64,
@@ -499,6 +501,12 @@ impl NetworkProxyState {
}
}
pub async fn mitm_state(&self) -> Result<Option<Arc<MitmState>>> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.mitm.clone())
}
pub async fn add_allowed_domain(&self, host: &str) -> Result<()> {
self.update_domain_list(host, DomainListKind::Allow).await
}

View File

@@ -1,10 +1,12 @@
use crate::config::NetworkMode;
use crate::config::NetworkProxyConfig;
use crate::mitm::MitmState;
use crate::policy::DomainPattern;
use crate::policy::compile_globset;
use crate::runtime::ConfigState;
use serde::Deserialize;
use std::collections::HashSet;
use std::sync::Arc;
pub use crate::runtime::BlockedRequest;
pub use crate::runtime::BlockedRequestArgs;
@@ -57,10 +59,18 @@ pub fn build_config_state(
crate::config::validate_unix_socket_allowlist_paths(&config)?;
let deny_set = compile_globset(&config.network.denied_domains)?;
let allow_set = compile_globset(&config.network.allowed_domains)?;
let mitm = if config.network.mitm {
Some(Arc::new(MitmState::new(
config.network.allow_upstream_proxy,
)?))
} else {
None
};
Ok(ConfigState {
config,
allow_set,
deny_set,
mitm,
constraints,
blocked: std::collections::VecDeque::new(),
blocked_total: 0,