Update network proxy rama deps

This commit is contained in:
viyatb-oai
2026-01-12 15:25:25 -08:00
parent 310c79eef5
commit d2042b92b6
30 changed files with 6010 additions and 253 deletions

View File

@@ -2,8 +2,8 @@ use crate::config::NetworkMode;
use crate::responses::json_response;
use crate::responses::text_response;
use crate::state::AppState;
use anyhow::Context;
use anyhow::Result;
use rama::Context as RamaContext;
use rama::http::Body;
use rama::http::Request;
use rama::http::Response;
@@ -19,28 +19,30 @@ use std::sync::Arc;
use tracing::error;
use tracing::info;
type ContextState = Arc<AppState>;
type AdminContext = RamaContext<ContextState>;
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
// Debug-only admin API (health/config/patterns/blocked + mode/reload). Policy is config-driven
// and constraint-enforced; this endpoint should not become a second policy/approval plane.
let listener = TcpListener::build_with_state(state)
let listener = TcpListener::build()
.bind(addr)
.await
.map_err(|err| anyhow::anyhow!("bind admin API: {err}"))?;
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
.map_err(rama::error::OpaqueError::from)
.map_err(anyhow::Error::from)
.with_context(|| format!("bind admin API: {addr}"))?;
let server =
HttpServer::auto(rama::rt::Executor::new()).service(service_fn(handle_admin_request));
let server_state = state.clone();
let server = HttpServer::auto(rama::rt::Executor::new()).service(service_fn(move |req| {
let state = server_state.clone();
async move { handle_admin_request(state, req).await }
}));
info!("admin API listening on {addr}");
listener.serve(server).await;
Ok(())
}
async fn handle_admin_request(ctx: AdminContext, req: Request) -> Result<Response, Infallible> {
async fn handle_admin_request(state: Arc<AppState>, req: Request) -> Result<Response, Infallible> {
const MODE_BODY_LIMIT: usize = 8 * 1024;
let state = ctx.state().clone();
let method = req.method().clone();
let path = req.uri().path().to_string();
let response = match (method.as_str(), path.as_str()) {

View File

@@ -341,20 +341,23 @@ mod tests {
fn resolve_addr_maps_localhost_to_loopback() {
assert_eq!(
resolve_addr("localhost", 3128),
"127.0.0.1:3128".parse().unwrap()
"127.0.0.1:3128".parse::<SocketAddr>().unwrap()
);
}
#[test]
fn resolve_addr_parses_ip_literals() {
assert_eq!(resolve_addr("1.2.3.4", 80), "1.2.3.4:80".parse().unwrap());
assert_eq!(
resolve_addr("1.2.3.4", 80),
"1.2.3.4:80".parse::<SocketAddr>().unwrap()
);
}
#[test]
fn resolve_addr_parses_ipv6_literals() {
assert_eq!(
resolve_addr("http://[::1]:8080", 3128),
"[::1]:8080".parse().unwrap()
"[::1]:8080".parse::<SocketAddr>().unwrap()
);
}
@@ -362,7 +365,7 @@ mod tests {
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
assert_eq!(
resolve_addr("http://example.com:5555", 3128),
"127.0.0.1:5555".parse().unwrap()
"127.0.0.1:5555".parse::<SocketAddr>().unwrap()
);
}
}

View File

@@ -4,9 +4,9 @@ use crate::policy::normalize_host;
use crate::responses::blocked_header_value;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context;
use anyhow::Context as _;
use anyhow::Result;
use rama::Context as RamaContext;
use rama::Context;
use rama::Layer;
use rama::Service;
use rama::http::Body;
@@ -14,13 +14,13 @@ use rama::http::Request;
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::client::EasyHttpWebClient;
use rama::http::dep::http::uri::PathAndQuery;
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
use rama::http::layer::upgrade::UpgradeLayer;
use rama::http::layer::upgrade::Upgraded;
use rama::http::matcher::MethodMatcher;
use rama::http::server::HttpServer;
use rama::layer::AddExtensionLayer;
use rama::net::http::RequestContext;
use rama::net::proxy::ProxyTarget;
use rama::net::stream::SocketInfo;
@@ -35,14 +35,16 @@ use tracing::error;
use tracing::info;
use tracing::warn;
type ContextState = Arc<AppState>;
type ProxyContext = RamaContext<ContextState>;
pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::build_with_state(state)
let listener = TcpListener::build()
.bind(addr)
.await
.map_err(|err| anyhow::anyhow!(err))
// Rama's `BoxError` is a `Box<dyn Error + Send + Sync>` without an explicit `'static`
// lifetime bound, which means it doesn't satisfy `anyhow::Context`'s `StdError` constraint.
// Wrap it in Rama's `OpaqueError` so we can preserve the original error as a source and
// still use `anyhow` for chaining.
.map_err(rama::error::OpaqueError::from)
.map_err(anyhow::Error::from)
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
let http_service = HttpServer::auto(rama::rt::Executor::new()).service(
@@ -60,14 +62,24 @@ pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()
info!("HTTP proxy listening on {addr}");
listener.serve(http_service).await;
listener
.serve(AddExtensionLayer::new(state).into_layer(http_service))
.await;
Ok(())
}
async fn http_connect_accept(
mut ctx: ProxyContext,
async fn http_connect_accept<S>(
mut ctx: Context<S>,
req: Request,
) -> Result<(Response, ProxyContext, Request), Response> {
) -> Result<(Response, Context<S>, Request), Response>
where
S: Clone + Send + Sync + 'static,
{
let app_state = ctx
.get::<Arc<AppState>>()
.cloned()
.ok_or_else(|| text_response(StatusCode::INTERNAL_SERVER_ERROR, "missing state"))?;
let authority = match ctx
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
.map(|ctx| ctx.authority.clone())
@@ -84,7 +96,6 @@ async fn http_connect_accept(
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
}
let app_state = ctx.state().clone();
let client = client_addr(&ctx);
match app_state.host_blocked(&host, authority.port()).await {
@@ -165,33 +176,25 @@ async fn http_connect_accept(
))
}
async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(), Infallible> {
async fn http_connect_proxy<S>(ctx: Context<S>, upgraded: Upgraded) -> Result<(), Infallible>
where
S: Clone + Send + Sync + 'static,
{
let mode = ctx
.get::<NetworkMode>()
.copied()
.unwrap_or(NetworkMode::Full);
let authority = match ctx.get::<ProxyTarget>().map(|target| target.0.clone()) {
Some(authority) => authority,
None => {
warn!("CONNECT missing proxy target");
return Ok(());
}
};
let host = normalize_host(&authority.host().to_string());
if let Some(mitm_state) = ctx.get::<Arc<mitm::MitmState>>().cloned() {
let port = authority.port();
let Some(target) = ctx.get::<ProxyTarget>().map(|t| t.0.clone()) else {
warn!("CONNECT missing proxy target");
return Ok(());
};
let host = normalize_host(&target.host().to_string());
if ctx.get::<Arc<mitm::MitmState>>().is_some() {
let port = target.port();
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
if let Err(err) = mitm::mitm_tunnel(
ctx,
upgraded,
host.as_str(),
authority.port(),
mode,
mitm_state,
)
.await
{
if let Err(err) = mitm::mitm_tunnel(ctx, upgraded).await {
warn!("MITM tunnel error: {err}");
}
return Ok(());
@@ -204,8 +207,17 @@ async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(),
Ok(())
}
async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Response, Infallible> {
let app_state = ctx.state().clone();
async fn http_plain_proxy<S>(mut ctx: Context<S>, req: Request) -> Result<Response, Infallible>
where
S: Clone + Send + Sync + 'static,
{
let app_state = match ctx.get::<Arc<AppState>>().cloned() {
Some(state) => state,
None => {
error!("missing app state");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
let client = client_addr(&ctx);
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
@@ -216,15 +228,20 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
}
};
if let Some(socket_path) = req
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly
// scoped: macOS-only + explicit allowlist, to avoid turning the proxy into a general local
// capability-escalation mechanism.
.headers()
.get("x-unix-socket")
.and_then(|v| v.to_str().ok())
.map(ToString::to_string)
{
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly scoped:
// macOS-only + explicit allowlist, to avoid turning the proxy into a general local capability
// escalation mechanism.
if let Some(unix_socket_header) = req.headers().get("x-unix-socket") {
let socket_path = match unix_socket_header.to_str() {
Ok(value) => value.to_string(),
Err(_) => {
warn!("invalid x-unix-socket header value (non-UTF8)");
return Ok(text_response(
StatusCode::BAD_REQUEST,
"invalid x-unix-socket header",
));
}
};
if !method_allowed {
let client = client.as_deref().unwrap_or_default();
let method = req.method();
@@ -338,11 +355,14 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
}
}
async fn proxy_via_unix_socket(
ctx: ProxyContext,
async fn proxy_via_unix_socket<S>(
ctx: Context<S>,
req: Request,
socket_path: &str,
) -> Result<Response> {
) -> Result<Response>
where
S: Clone + Send + Sync + 'static,
{
#[cfg(target_os = "macos")]
{
use rama::unix::client::UnixConnector;
@@ -358,7 +378,7 @@ async fn proxy_via_unix_socket(
let path = parts
.uri
.path_and_query()
.map(PathAndQuery::as_str)
.map(rama::http::dep::http::uri::PathAndQuery::as_str)
.unwrap_or("/");
parts.uri = path
.parse()
@@ -370,14 +390,14 @@ async fn proxy_via_unix_socket(
}
#[cfg(not(target_os = "macos"))]
{
let _ = ctx;
let _ = req;
let _ = ctx;
let _ = socket_path;
Err(anyhow::anyhow!("unix sockets not supported"))
}
}
fn client_addr(ctx: &ProxyContext) -> Option<String> {
fn client_addr<S>(ctx: &Context<S>) -> Option<String> {
ctx.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string())
}

View File

@@ -5,10 +5,10 @@ use crate::policy::normalize_host;
use crate::responses::blocked_text_response;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context;
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use rama::Context as RamaContext;
use rama::Context;
use rama::Layer;
use rama::Service;
use rama::bytes::Bytes;
@@ -21,7 +21,7 @@ use rama::http::Request;
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::Uri;
use rama::http::dep::http::uri::PathAndQuery;
use rama::http::client::EasyHttpWebClient;
use rama::http::header::HOST;
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
@@ -30,14 +30,15 @@ use rama::http::server::HttpServer;
use rama::net::proxy::ProxyTarget;
use rama::net::stream::SocketInfo;
use rama::service::service_fn;
use rama::tls::rustls::dep::pemfile;
use rama::tls::rustls::dep::pki_types::CertificateDer;
use rama::tls::rustls::dep::pki_types::PrivateKeyDer;
use rama::tls::rustls::dep::pki_types::pem::PemObject;
use rama::tls::rustls::server::TlsAcceptorData;
use rama::tls::rustls::server::TlsAcceptorDataBuilder;
use rama::tls::rustls::server::TlsAcceptorLayer;
use std::fs;
use std::fs::File;
use std::fs::OpenOptions;
use std::io::BufReader;
use std::io::Write;
use std::net::IpAddr;
use std::pin::Pin;
@@ -62,11 +63,21 @@ use rcgen_rama::SanType;
pub struct MitmState {
issuer: Issuer<'static, KeyPair>,
upstream: rama::service::BoxService<Arc<AppState>, Request, Response, OpaqueError>,
upstream: rama::service::BoxService<(), Request, Response, OpaqueError>,
inspect: bool,
max_body_bytes: usize,
}
impl std::fmt::Debug for MitmState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Avoid dumping internal state (CA material, connectors, etc.) to logs.
f.debug_struct("MitmState")
.field("inspect", &self.inspect)
.field("max_body_bytes", &self.max_body_bytes)
.finish_non_exhaustive()
}
}
impl MitmState {
pub fn new(cfg: &MitmConfig) -> Result<Self> {
// MITM exists to make limited-mode HTTPS enforceable: once CONNECT is established, plain
@@ -79,14 +90,15 @@ impl MitmState {
let tls_config = rama::tls::rustls::client::TlsConnectorData::new_http_auto()
.context("create upstream TLS config")?;
let upstream = rama::http::client::EasyHttpWebClient::builder()
// Use a direct transport connector (no upstream proxy) to avoid proxy loops.
.with_default_transport_connector()
.without_tls_proxy_support()
.without_proxy_support()
.with_tls_support_using_rustls(Some(tls_config))
.build()
.boxed();
let upstream: rama::service::BoxService<(), Request, Response, OpaqueError> =
EasyHttpWebClient::builder()
// Use a direct transport connector (no upstream proxy) to avoid proxy loops.
.with_default_transport_connector()
.without_tls_proxy_support()
.without_proxy_support()
.with_tls_support_using_rustls(Some(tls_config))
.build()
.boxed();
Ok(Self {
issuer,
@@ -98,16 +110,15 @@ impl MitmState {
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
let cert_chain = pemfile::certs(&mut BufReader::new(cert_pem.as_bytes()))
let cert_chain = CertificateDer::pem_slice_iter(cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()
.context("failed to parse host cert PEM")?;
if cert_chain.is_empty() {
return Err(anyhow!("no certificates found"));
}
let key_der = pemfile::private_key(&mut BufReader::new(key_pem.as_bytes()))
.context("failed to parse host key PEM")?
.context("no private key found")?;
let key_der = PrivateKeyDer::from_pem_slice(key_pem.as_bytes())
.context("failed to parse host key PEM")?;
Ok(TlsAcceptorDataBuilder::new(cert_chain, key_der)
.context("failed to build rustls acceptor config")?
@@ -124,20 +135,25 @@ impl MitmState {
}
}
pub async fn mitm_tunnel(
mut ctx: RamaContext<Arc<AppState>>,
upgraded: Upgraded,
host: &str,
_port: u16,
mode: NetworkMode,
state: Arc<MitmState>,
) -> Result<()> {
// Ensure the MITM state is available for the per-request handler.
ctx.insert(state.clone());
ctx.insert(mode);
pub async fn mitm_tunnel<S>(ctx: Context<S>, upgraded: Upgraded) -> Result<()>
where
S: Clone + Send + Sync + 'static,
{
let state = ctx
.get::<Arc<MitmState>>()
.cloned()
.context("missing MITM state")?;
let target = ctx
.get::<ProxyTarget>()
.context("missing proxy target")?
.0
.clone();
let host = normalize_host(&target.host().to_string());
let acceptor_data = state.tls_acceptor_data_for_host(&host)?;
let acceptor_data = state.tls_acceptor_data_for_host(host)?;
let http_service = HttpServer::auto(ctx.executor().clone()).service(
let executor = ctx.executor().clone();
let http_service = HttpServer::auto(executor).service(
(
RemoveResponseHeaderLayer::hop_by_hop(),
RemoveRequestHeaderLayer::hop_by_hop(),
@@ -156,10 +172,13 @@ pub async fn mitm_tunnel(
Ok(())
}
async fn handle_mitm_request(
ctx: RamaContext<Arc<AppState>>,
async fn handle_mitm_request<S>(
ctx: Context<S>,
req: Request,
) -> Result<Response, std::convert::Infallible> {
) -> Result<Response, std::convert::Infallible>
where
S: Clone + Send + Sync + 'static,
{
let response = match forward_request(ctx, req).await {
Ok(resp) => resp,
Err(err) => {
@@ -170,7 +189,10 @@ async fn handle_mitm_request(
Ok(response)
}
async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Result<Response> {
async fn forward_request<S>(ctx: Context<S>, req: Request) -> Result<Response>
where
S: Clone + Send + Sync + 'static,
{
let target = ctx
.get::<ProxyTarget>()
.context("missing proxy target")?
@@ -187,6 +209,10 @@ async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Resul
.get::<Arc<MitmState>>()
.cloned()
.context("missing MITM state")?;
let app_state = ctx
.get::<Arc<AppState>>()
.cloned()
.context("missing app state")?;
if req.method().as_str() == "CONNECT" {
return Ok(text_response(
@@ -210,8 +236,7 @@ async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Resul
}
if !method_allowed(mode, method.as_str()) {
let _ = ctx
.state()
let _ = app_state
.record_blocked(BlockedRequest::new(
target_host.clone(),
"method_not_allowed".to_string(),
@@ -251,7 +276,10 @@ async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Resul
};
let upstream_req = Request::from_parts(parts, body);
let upstream_resp = mitm.upstream.serve(ctx, upstream_req).await?;
let upstream_resp = mitm
.upstream
.serve(ctx.map_state(|_| ()), upstream_req)
.await?;
respond_with_inspection(
upstream_resp,
inspect,
@@ -400,7 +428,7 @@ fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(PathAndQuery::as_str)
.map(rama::http::dep::http::uri::PathAndQuery::as_str)
.unwrap_or("/")
.to_string()
}
@@ -519,22 +547,41 @@ fn write_atomic_create_new(path: &std::path::Path, contents: &[u8], mode: u32) -
.with_context(|| format!("failed to fsync {}", tmp_path.display()))?;
drop(file);
if path.exists() {
let _ = fs::remove_file(&tmp_path);
return Err(anyhow!(
"refusing to overwrite existing file {}",
path.display()
));
// Create the final file using "create-new" semantics (no overwrite). `rename` on Unix can
// overwrite existing files, so prefer a hard-link, which fails if the destination exists.
match fs::hard_link(&tmp_path, path) {
Ok(()) => {
fs::remove_file(&tmp_path)
.with_context(|| format!("failed to remove {}", tmp_path.display()))?;
}
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
let _ = fs::remove_file(&tmp_path);
return Err(anyhow!(
"refusing to overwrite existing file {}",
path.display()
));
}
Err(_) => {
// Best-effort fallback for environments where hard links are not supported.
// This is still subject to a TOCTOU race, but the typical case is a private per-user
// config directory, where other users cannot create files anyway.
if path.exists() {
let _ = fs::remove_file(&tmp_path);
return Err(anyhow!(
"refusing to overwrite existing file {}",
path.display()
));
}
fs::rename(&tmp_path, path).with_context(|| {
format!(
"failed to rename {} -> {}",
tmp_path.display(),
path.display()
)
})?;
}
}
fs::rename(&tmp_path, path).with_context(|| {
format!(
"failed to rename {} -> {}",
tmp_path.display(),
path.display()
)
})?;
// Best-effort durability: ensure the directory entry is persisted too.
let dir = File::open(parent).with_context(|| format!("failed to open {}", parent.display()))?;
dir.sync_all()

View File

@@ -2,10 +2,12 @@ use crate::config::NetworkMode;
use crate::policy::normalize_host;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use rama::Context as RamaContext;
use rama::Context;
use rama::Layer;
use rama::Service;
use rama::layer::AddExtensionLayer;
use rama::net::stream::SocketInfo;
use rama::proxy::socks5::Socks5Acceptor;
use rama::proxy::socks5::server::DefaultConnector;
@@ -21,10 +23,13 @@ use tracing::info;
use tracing::warn;
pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::build_with_state(state.clone())
let listener = TcpListener::build()
.bind(addr)
.await
.map_err(|err| anyhow!("bind SOCKS5 proxy: {err}"))?;
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
.map_err(rama::error::OpaqueError::from)
.map_err(anyhow::Error::from)
.with_context(|| format!("bind SOCKS5 proxy: {addr}"))?;
info!("SOCKS5 proxy listening on {addr}");
@@ -39,81 +44,80 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
}
let tcp_connector = TcpConnector::default();
let policy_tcp_connector = service_fn(
move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
let tcp_connector = tcp_connector.clone();
async move {
let app_state = ctx.state().clone();
let authority = req.authority().clone();
let host = normalize_host(&authority.host().to_string());
let port = authority.port();
let client = ctx
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
let policy_tcp_connector = service_fn(move |ctx: Context<()>, req: TcpRequest| {
let tcp_connector = tcp_connector.clone();
async move {
let app_state = ctx
.get::<Arc<AppState>>()
.cloned()
.ok_or_else(|| io::Error::other("missing state"))?;
match app_state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
client.clone(),
None,
Some(NetworkMode::Limited),
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!("failed to evaluate method policy: {err}");
return Err(io::Error::other("proxy error").into());
}
let host = normalize_host(&req.authority().host().to_string());
let port = req.authority().port();
let client = ctx
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
match app_state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
client.clone(),
None,
Some(NetworkMode::Limited),
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Err(io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into());
}
match app_state.host_blocked(&host, port).await {
Ok((true, reason)) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
None,
None,
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok((false, _)) => {
let client = client.as_deref().unwrap_or_default();
info!("SOCKS allowed (client={client}, host={host}, port={port})");
}
Err(err) => {
error!("failed to evaluate host: {err}");
return Err(io::Error::other("proxy error").into());
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!("failed to evaluate method policy: {err}");
return Err(io::Error::other("proxy error").into());
}
tcp_connector.serve(ctx, req).await
}
},
);
match app_state.host_blocked(&host, port).await {
Ok((true, reason)) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
None,
None,
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
return Err(io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into());
}
Ok((false, _)) => {
let client = client.as_deref().unwrap_or_default();
info!("SOCKS allowed (client={client}, host={host}, port={port})");
}
Err(err) => {
error!("failed to evaluate host: {err}");
return Err(io::Error::other("proxy error").into());
}
}
tcp_connector.serve(ctx, req).await
}
});
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
let socks_acceptor = Socks5Acceptor::new().with_connector(socks_connector);
listener.serve(socks_acceptor).await;
listener
.serve(AddExtensionLayer::new(state).into_layer(socks_acceptor))
.await;
Ok(())
}

View File

@@ -12,6 +12,7 @@ use codex_core::config::CONFIG_TOML_FILE;
use codex_core::config::ConfigBuilder;
use codex_core::config::Constrained;
use codex_core::config::ConstraintError;
use codex_core::config_loader::RequirementSource;
use globset::GlobBuilder;
use globset::GlobSet;
use globset::GlobSetBuilder;
@@ -80,6 +81,14 @@ pub struct AppState {
state: Arc<RwLock<ConfigState>>,
}
impl std::fmt::Debug for AppState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Avoid logging internal state (config contents, derived globsets, etc.) which can be noisy
// and may contain sensitive paths.
f.debug_struct("AppState").finish_non_exhaustive()
}
}
impl AppState {
pub async fn new() -> Result<Self> {
let cfg_state = build_config_state().await?;
@@ -467,11 +476,25 @@ fn validate_policy_against_constraints(
config: &Config,
constraints: &NetworkProxyConstraints,
) -> std::result::Result<(), ConstraintError> {
fn invalid_value(
field_name: &'static str,
candidate: impl Into<String>,
allowed: impl Into<String>,
) -> ConstraintError {
ConstraintError::InvalidValue {
field_name,
candidate: candidate.into(),
allowed: allowed.into(),
requirement_source: RequirementSource::Unknown,
}
}
let enabled = config.network_proxy.enabled;
if let Some(max_enabled) = constraints.enabled {
let _ = Constrained::new(enabled, move |candidate| {
if *candidate && !max_enabled {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.enabled",
"true",
"false (disabled by managed config)",
))
@@ -484,7 +507,8 @@ fn validate_policy_against_constraints(
if let Some(max_mode) = constraints.mode {
let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.mode",
format!("{candidate:?}"),
format!("{max_mode:?} or more restrictive"),
))
@@ -501,7 +525,8 @@ fn validate_policy_against_constraints(
Some(true) | None => Ok(()),
Some(false) => {
if *candidate {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.dangerously_allow_non_loopback_admin",
"true",
"false (disabled by managed config)",
))
@@ -519,7 +544,8 @@ fn validate_policy_against_constraints(
Some(true) | None => Ok(()),
Some(false) => {
if *candidate {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.dangerously_allow_non_loopback_proxy",
"true",
"false (disabled by managed config)",
))
@@ -535,7 +561,8 @@ fn validate_policy_against_constraints(
config.network_proxy.policy.allow_local_binding,
move |candidate| {
if *candidate && !allow_local_binding {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.policy.allow_local_binding",
"true",
"false (disabled by managed config)",
))
@@ -563,7 +590,8 @@ fn validate_policy_against_constraints(
if invalid.is_empty() {
Ok(())
} else {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.policy.allowed_domains",
format!("{invalid:?}"),
"subset of managed allowed_domains",
))
@@ -590,7 +618,8 @@ fn validate_policy_against_constraints(
if missing.is_empty() {
Ok(())
} else {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.policy.denied_domains",
"missing managed denied_domains entries",
format!("{missing:?}"),
))
@@ -616,7 +645,8 @@ fn validate_policy_against_constraints(
if invalid.is_empty() {
Ok(())
} else {
Err(ConstraintError::invalid_value(
Err(invalid_value(
"network_proxy.policy.allow_unix_sockets",
format!("{invalid:?}"),
"subset of managed allow_unix_sockets",
))
@@ -766,7 +796,9 @@ mod tests {
(false, String::new())
);
assert_eq!(
state.host_blocked("not-example.com", 80).await.unwrap(),
// Use a public IP literal to avoid relying on ambient DNS behavior (some networks
// resolve unknown hostnames to private IPs, which would trigger `not_allowed_local`).
state.host_blocked("8.8.8.8", 80).await.unwrap(),
(true, "not_allowed".to_string())
);
}