mirror of
https://github.com/openai/codex.git
synced 2026-04-30 01:16:54 +00:00
Update network proxy rama deps
This commit is contained in:
@@ -2,8 +2,8 @@ use crate::config::NetworkMode;
|
||||
use crate::responses::json_response;
|
||||
use crate::responses::text_response;
|
||||
use crate::state::AppState;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::http::Body;
|
||||
use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
@@ -19,28 +19,30 @@ use std::sync::Arc;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
|
||||
type ContextState = Arc<AppState>;
|
||||
type AdminContext = RamaContext<ContextState>;
|
||||
|
||||
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
// Debug-only admin API (health/config/patterns/blocked + mode/reload). Policy is config-driven
|
||||
// and constraint-enforced; this endpoint should not become a second policy/approval plane.
|
||||
let listener = TcpListener::build_with_state(state)
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!("bind admin API: {err}"))?;
|
||||
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
|
||||
.map_err(rama::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind admin API: {addr}"))?;
|
||||
|
||||
let server =
|
||||
HttpServer::auto(rama::rt::Executor::new()).service(service_fn(handle_admin_request));
|
||||
let server_state = state.clone();
|
||||
let server = HttpServer::auto(rama::rt::Executor::new()).service(service_fn(move |req| {
|
||||
let state = server_state.clone();
|
||||
async move { handle_admin_request(state, req).await }
|
||||
}));
|
||||
info!("admin API listening on {addr}");
|
||||
listener.serve(server).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_admin_request(ctx: AdminContext, req: Request) -> Result<Response, Infallible> {
|
||||
async fn handle_admin_request(state: Arc<AppState>, req: Request) -> Result<Response, Infallible> {
|
||||
const MODE_BODY_LIMIT: usize = 8 * 1024;
|
||||
|
||||
let state = ctx.state().clone();
|
||||
let method = req.method().clone();
|
||||
let path = req.uri().path().to_string();
|
||||
let response = match (method.as_str(), path.as_str()) {
|
||||
|
||||
@@ -341,20 +341,23 @@ mod tests {
|
||||
fn resolve_addr_maps_localhost_to_loopback() {
|
||||
assert_eq!(
|
||||
resolve_addr("localhost", 3128),
|
||||
"127.0.0.1:3128".parse().unwrap()
|
||||
"127.0.0.1:3128".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ip_literals() {
|
||||
assert_eq!(resolve_addr("1.2.3.4", 80), "1.2.3.4:80".parse().unwrap());
|
||||
assert_eq!(
|
||||
resolve_addr("1.2.3.4", 80),
|
||||
"1.2.3.4:80".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ipv6_literals() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://[::1]:8080", 3128),
|
||||
"[::1]:8080".parse().unwrap()
|
||||
"[::1]:8080".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -362,7 +365,7 @@ mod tests {
|
||||
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://example.com:5555", 3128),
|
||||
"127.0.0.1:5555".parse().unwrap()
|
||||
"127.0.0.1:5555".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ use crate::policy::normalize_host;
|
||||
use crate::responses::blocked_header_value;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Context;
|
||||
use rama::Layer;
|
||||
use rama::Service;
|
||||
use rama::http::Body;
|
||||
@@ -14,13 +14,13 @@ use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::client::EasyHttpWebClient;
|
||||
use rama::http::dep::http::uri::PathAndQuery;
|
||||
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama::http::layer::upgrade::UpgradeLayer;
|
||||
use rama::http::layer::upgrade::Upgraded;
|
||||
use rama::http::matcher::MethodMatcher;
|
||||
use rama::http::server::HttpServer;
|
||||
use rama::layer::AddExtensionLayer;
|
||||
use rama::net::http::RequestContext;
|
||||
use rama::net::proxy::ProxyTarget;
|
||||
use rama::net::stream::SocketInfo;
|
||||
@@ -35,14 +35,16 @@ use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
type ContextState = Arc<AppState>;
|
||||
type ProxyContext = RamaContext<ContextState>;
|
||||
|
||||
pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
let listener = TcpListener::build_with_state(state)
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!(err))
|
||||
// Rama's `BoxError` is a `Box<dyn Error + Send + Sync>` without an explicit `'static`
|
||||
// lifetime bound, which means it doesn't satisfy `anyhow::Context`'s `StdError` constraint.
|
||||
// Wrap it in Rama's `OpaqueError` so we can preserve the original error as a source and
|
||||
// still use `anyhow` for chaining.
|
||||
.map_err(rama::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
|
||||
|
||||
let http_service = HttpServer::auto(rama::rt::Executor::new()).service(
|
||||
@@ -60,14 +62,24 @@ pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()
|
||||
|
||||
info!("HTTP proxy listening on {addr}");
|
||||
|
||||
listener.serve(http_service).await;
|
||||
listener
|
||||
.serve(AddExtensionLayer::new(state).into_layer(http_service))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn http_connect_accept(
|
||||
mut ctx: ProxyContext,
|
||||
async fn http_connect_accept<S>(
|
||||
mut ctx: Context<S>,
|
||||
req: Request,
|
||||
) -> Result<(Response, ProxyContext, Request), Response> {
|
||||
) -> Result<(Response, Context<S>, Request), Response>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let app_state = ctx
|
||||
.get::<Arc<AppState>>()
|
||||
.cloned()
|
||||
.ok_or_else(|| text_response(StatusCode::INTERNAL_SERVER_ERROR, "missing state"))?;
|
||||
|
||||
let authority = match ctx
|
||||
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
|
||||
.map(|ctx| ctx.authority.clone())
|
||||
@@ -84,7 +96,6 @@ async fn http_connect_accept(
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
|
||||
}
|
||||
|
||||
let app_state = ctx.state().clone();
|
||||
let client = client_addr(&ctx);
|
||||
|
||||
match app_state.host_blocked(&host, authority.port()).await {
|
||||
@@ -165,33 +176,25 @@ async fn http_connect_accept(
|
||||
))
|
||||
}
|
||||
|
||||
async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(), Infallible> {
|
||||
async fn http_connect_proxy<S>(ctx: Context<S>, upgraded: Upgraded) -> Result<(), Infallible>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let mode = ctx
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
let authority = match ctx.get::<ProxyTarget>().map(|target| target.0.clone()) {
|
||||
Some(authority) => authority,
|
||||
None => {
|
||||
warn!("CONNECT missing proxy target");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
|
||||
if let Some(mitm_state) = ctx.get::<Arc<mitm::MitmState>>().cloned() {
|
||||
let port = authority.port();
|
||||
let Some(target) = ctx.get::<ProxyTarget>().map(|t| t.0.clone()) else {
|
||||
warn!("CONNECT missing proxy target");
|
||||
return Ok(());
|
||||
};
|
||||
let host = normalize_host(&target.host().to_string());
|
||||
|
||||
if ctx.get::<Arc<mitm::MitmState>>().is_some() {
|
||||
let port = target.port();
|
||||
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
|
||||
if let Err(err) = mitm::mitm_tunnel(
|
||||
ctx,
|
||||
upgraded,
|
||||
host.as_str(),
|
||||
authority.port(),
|
||||
mode,
|
||||
mitm_state,
|
||||
)
|
||||
.await
|
||||
{
|
||||
if let Err(err) = mitm::mitm_tunnel(ctx, upgraded).await {
|
||||
warn!("MITM tunnel error: {err}");
|
||||
}
|
||||
return Ok(());
|
||||
@@ -204,8 +207,17 @@ async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(),
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Response, Infallible> {
|
||||
let app_state = ctx.state().clone();
|
||||
async fn http_plain_proxy<S>(mut ctx: Context<S>, req: Request) -> Result<Response, Infallible>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let app_state = match ctx.get::<Arc<AppState>>().cloned() {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
error!("missing app state");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
let client = client_addr(&ctx);
|
||||
|
||||
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
|
||||
@@ -216,15 +228,20 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(socket_path) = req
|
||||
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly
|
||||
// scoped: macOS-only + explicit allowlist, to avoid turning the proxy into a general local
|
||||
// capability-escalation mechanism.
|
||||
.headers()
|
||||
.get("x-unix-socket")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(ToString::to_string)
|
||||
{
|
||||
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly scoped:
|
||||
// macOS-only + explicit allowlist, to avoid turning the proxy into a general local capability
|
||||
// escalation mechanism.
|
||||
if let Some(unix_socket_header) = req.headers().get("x-unix-socket") {
|
||||
let socket_path = match unix_socket_header.to_str() {
|
||||
Ok(value) => value.to_string(),
|
||||
Err(_) => {
|
||||
warn!("invalid x-unix-socket header value (non-UTF8)");
|
||||
return Ok(text_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid x-unix-socket header",
|
||||
));
|
||||
}
|
||||
};
|
||||
if !method_allowed {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
@@ -338,11 +355,14 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_via_unix_socket(
|
||||
ctx: ProxyContext,
|
||||
async fn proxy_via_unix_socket<S>(
|
||||
ctx: Context<S>,
|
||||
req: Request,
|
||||
socket_path: &str,
|
||||
) -> Result<Response> {
|
||||
) -> Result<Response>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
use rama::unix::client::UnixConnector;
|
||||
@@ -358,7 +378,7 @@ async fn proxy_via_unix_socket(
|
||||
let path = parts
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map(PathAndQuery::as_str)
|
||||
.map(rama::http::dep::http::uri::PathAndQuery::as_str)
|
||||
.unwrap_or("/");
|
||||
parts.uri = path
|
||||
.parse()
|
||||
@@ -370,14 +390,14 @@ async fn proxy_via_unix_socket(
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = ctx;
|
||||
let _ = req;
|
||||
let _ = ctx;
|
||||
let _ = socket_path;
|
||||
Err(anyhow::anyhow!("unix sockets not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
fn client_addr(ctx: &ProxyContext) -> Option<String> {
|
||||
fn client_addr<S>(ctx: &Context<S>) -> Option<String> {
|
||||
ctx.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string())
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@ use crate::policy::normalize_host;
|
||||
use crate::responses::blocked_text_response;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Context;
|
||||
use rama::Layer;
|
||||
use rama::Service;
|
||||
use rama::bytes::Bytes;
|
||||
@@ -21,7 +21,7 @@ use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::Uri;
|
||||
use rama::http::dep::http::uri::PathAndQuery;
|
||||
use rama::http::client::EasyHttpWebClient;
|
||||
use rama::http::header::HOST;
|
||||
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
@@ -30,14 +30,15 @@ use rama::http::server::HttpServer;
|
||||
use rama::net::proxy::ProxyTarget;
|
||||
use rama::net::stream::SocketInfo;
|
||||
use rama::service::service_fn;
|
||||
use rama::tls::rustls::dep::pemfile;
|
||||
use rama::tls::rustls::dep::pki_types::CertificateDer;
|
||||
use rama::tls::rustls::dep::pki_types::PrivateKeyDer;
|
||||
use rama::tls::rustls::dep::pki_types::pem::PemObject;
|
||||
use rama::tls::rustls::server::TlsAcceptorData;
|
||||
use rama::tls::rustls::server::TlsAcceptorDataBuilder;
|
||||
use rama::tls::rustls::server::TlsAcceptorLayer;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::BufReader;
|
||||
use std::io::Write;
|
||||
use std::net::IpAddr;
|
||||
use std::pin::Pin;
|
||||
@@ -62,11 +63,21 @@ use rcgen_rama::SanType;
|
||||
|
||||
pub struct MitmState {
|
||||
issuer: Issuer<'static, KeyPair>,
|
||||
upstream: rama::service::BoxService<Arc<AppState>, Request, Response, OpaqueError>,
|
||||
upstream: rama::service::BoxService<(), Request, Response, OpaqueError>,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MitmState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Avoid dumping internal state (CA material, connectors, etc.) to logs.
|
||||
f.debug_struct("MitmState")
|
||||
.field("inspect", &self.inspect)
|
||||
.field("max_body_bytes", &self.max_body_bytes)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl MitmState {
|
||||
pub fn new(cfg: &MitmConfig) -> Result<Self> {
|
||||
// MITM exists to make limited-mode HTTPS enforceable: once CONNECT is established, plain
|
||||
@@ -79,14 +90,15 @@ impl MitmState {
|
||||
|
||||
let tls_config = rama::tls::rustls::client::TlsConnectorData::new_http_auto()
|
||||
.context("create upstream TLS config")?;
|
||||
let upstream = rama::http::client::EasyHttpWebClient::builder()
|
||||
// Use a direct transport connector (no upstream proxy) to avoid proxy loops.
|
||||
.with_default_transport_connector()
|
||||
.without_tls_proxy_support()
|
||||
.without_proxy_support()
|
||||
.with_tls_support_using_rustls(Some(tls_config))
|
||||
.build()
|
||||
.boxed();
|
||||
let upstream: rama::service::BoxService<(), Request, Response, OpaqueError> =
|
||||
EasyHttpWebClient::builder()
|
||||
// Use a direct transport connector (no upstream proxy) to avoid proxy loops.
|
||||
.with_default_transport_connector()
|
||||
.without_tls_proxy_support()
|
||||
.without_proxy_support()
|
||||
.with_tls_support_using_rustls(Some(tls_config))
|
||||
.build()
|
||||
.boxed();
|
||||
|
||||
Ok(Self {
|
||||
issuer,
|
||||
@@ -98,16 +110,15 @@ impl MitmState {
|
||||
|
||||
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
|
||||
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
|
||||
let cert_chain = pemfile::certs(&mut BufReader::new(cert_pem.as_bytes()))
|
||||
let cert_chain = CertificateDer::pem_slice_iter(cert_pem.as_bytes())
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.context("failed to parse host cert PEM")?;
|
||||
if cert_chain.is_empty() {
|
||||
return Err(anyhow!("no certificates found"));
|
||||
}
|
||||
|
||||
let key_der = pemfile::private_key(&mut BufReader::new(key_pem.as_bytes()))
|
||||
.context("failed to parse host key PEM")?
|
||||
.context("no private key found")?;
|
||||
let key_der = PrivateKeyDer::from_pem_slice(key_pem.as_bytes())
|
||||
.context("failed to parse host key PEM")?;
|
||||
|
||||
Ok(TlsAcceptorDataBuilder::new(cert_chain, key_der)
|
||||
.context("failed to build rustls acceptor config")?
|
||||
@@ -124,20 +135,25 @@ impl MitmState {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mitm_tunnel(
|
||||
mut ctx: RamaContext<Arc<AppState>>,
|
||||
upgraded: Upgraded,
|
||||
host: &str,
|
||||
_port: u16,
|
||||
mode: NetworkMode,
|
||||
state: Arc<MitmState>,
|
||||
) -> Result<()> {
|
||||
// Ensure the MITM state is available for the per-request handler.
|
||||
ctx.insert(state.clone());
|
||||
ctx.insert(mode);
|
||||
pub async fn mitm_tunnel<S>(ctx: Context<S>, upgraded: Upgraded) -> Result<()>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let state = ctx
|
||||
.get::<Arc<MitmState>>()
|
||||
.cloned()
|
||||
.context("missing MITM state")?;
|
||||
let target = ctx
|
||||
.get::<ProxyTarget>()
|
||||
.context("missing proxy target")?
|
||||
.0
|
||||
.clone();
|
||||
let host = normalize_host(&target.host().to_string());
|
||||
let acceptor_data = state.tls_acceptor_data_for_host(&host)?;
|
||||
|
||||
let acceptor_data = state.tls_acceptor_data_for_host(host)?;
|
||||
let http_service = HttpServer::auto(ctx.executor().clone()).service(
|
||||
let executor = ctx.executor().clone();
|
||||
|
||||
let http_service = HttpServer::auto(executor).service(
|
||||
(
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
@@ -156,10 +172,13 @@ pub async fn mitm_tunnel(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_mitm_request(
|
||||
ctx: RamaContext<Arc<AppState>>,
|
||||
async fn handle_mitm_request<S>(
|
||||
ctx: Context<S>,
|
||||
req: Request,
|
||||
) -> Result<Response, std::convert::Infallible> {
|
||||
) -> Result<Response, std::convert::Infallible>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let response = match forward_request(ctx, req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
@@ -170,7 +189,10 @@ async fn handle_mitm_request(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Result<Response> {
|
||||
async fn forward_request<S>(ctx: Context<S>, req: Request) -> Result<Response>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
let target = ctx
|
||||
.get::<ProxyTarget>()
|
||||
.context("missing proxy target")?
|
||||
@@ -187,6 +209,10 @@ async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Resul
|
||||
.get::<Arc<MitmState>>()
|
||||
.cloned()
|
||||
.context("missing MITM state")?;
|
||||
let app_state = ctx
|
||||
.get::<Arc<AppState>>()
|
||||
.cloned()
|
||||
.context("missing app state")?;
|
||||
|
||||
if req.method().as_str() == "CONNECT" {
|
||||
return Ok(text_response(
|
||||
@@ -210,8 +236,7 @@ async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Resul
|
||||
}
|
||||
|
||||
if !method_allowed(mode, method.as_str()) {
|
||||
let _ = ctx
|
||||
.state()
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
target_host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
@@ -251,7 +276,10 @@ async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Resul
|
||||
};
|
||||
|
||||
let upstream_req = Request::from_parts(parts, body);
|
||||
let upstream_resp = mitm.upstream.serve(ctx, upstream_req).await?;
|
||||
let upstream_resp = mitm
|
||||
.upstream
|
||||
.serve(ctx.map_state(|_| ()), upstream_req)
|
||||
.await?;
|
||||
respond_with_inspection(
|
||||
upstream_resp,
|
||||
inspect,
|
||||
@@ -400,7 +428,7 @@ fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
|
||||
|
||||
fn path_and_query(uri: &Uri) -> String {
|
||||
uri.path_and_query()
|
||||
.map(PathAndQuery::as_str)
|
||||
.map(rama::http::dep::http::uri::PathAndQuery::as_str)
|
||||
.unwrap_or("/")
|
||||
.to_string()
|
||||
}
|
||||
@@ -519,22 +547,41 @@ fn write_atomic_create_new(path: &std::path::Path, contents: &[u8], mode: u32) -
|
||||
.with_context(|| format!("failed to fsync {}", tmp_path.display()))?;
|
||||
drop(file);
|
||||
|
||||
if path.exists() {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
// Create the final file using "create-new" semantics (no overwrite). `rename` on Unix can
|
||||
// overwrite existing files, so prefer a hard-link, which fails if the destination exists.
|
||||
match fs::hard_link(&tmp_path, path) {
|
||||
Ok(()) => {
|
||||
fs::remove_file(&tmp_path)
|
||||
.with_context(|| format!("failed to remove {}", tmp_path.display()))?;
|
||||
}
|
||||
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
// Best-effort fallback for environments where hard links are not supported.
|
||||
// This is still subject to a TOCTOU race, but the typical case is a private per-user
|
||||
// config directory, where other users cannot create files anyway.
|
||||
if path.exists() {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
fs::rename(&tmp_path, path).with_context(|| {
|
||||
format!(
|
||||
"failed to rename {} -> {}",
|
||||
tmp_path.display(),
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
fs::rename(&tmp_path, path).with_context(|| {
|
||||
format!(
|
||||
"failed to rename {} -> {}",
|
||||
tmp_path.display(),
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
// Best-effort durability: ensure the directory entry is persisted too.
|
||||
let dir = File::open(parent).with_context(|| format!("failed to open {}", parent.display()))?;
|
||||
dir.sync_all()
|
||||
|
||||
@@ -2,10 +2,12 @@ use crate::config::NetworkMode;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Context;
|
||||
use rama::Layer;
|
||||
use rama::Service;
|
||||
use rama::layer::AddExtensionLayer;
|
||||
use rama::net::stream::SocketInfo;
|
||||
use rama::proxy::socks5::Socks5Acceptor;
|
||||
use rama::proxy::socks5::server::DefaultConnector;
|
||||
@@ -21,10 +23,13 @@ use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
let listener = TcpListener::build_with_state(state.clone())
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow!("bind SOCKS5 proxy: {err}"))?;
|
||||
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
|
||||
.map_err(rama::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind SOCKS5 proxy: {addr}"))?;
|
||||
|
||||
info!("SOCKS5 proxy listening on {addr}");
|
||||
|
||||
@@ -39,81 +44,80 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
}
|
||||
|
||||
let tcp_connector = TcpConnector::default();
|
||||
let policy_tcp_connector = service_fn(
|
||||
move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
|
||||
let tcp_connector = tcp_connector.clone();
|
||||
async move {
|
||||
let app_state = ctx.state().clone();
|
||||
let authority = req.authority().clone();
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
let port = authority.port();
|
||||
let client = ctx
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
let policy_tcp_connector = service_fn(move |ctx: Context<()>, req: TcpRequest| {
|
||||
let tcp_connector = tcp_connector.clone();
|
||||
async move {
|
||||
let app_state = ctx
|
||||
.get::<Arc<AppState>>()
|
||||
.cloned()
|
||||
.ok_or_else(|| io::Error::other("missing state"))?;
|
||||
|
||||
match app_state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
Some(NetworkMode::Limited),
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
let host = normalize_host(&req.authority().host().to_string());
|
||||
let port = req.authority().port();
|
||||
let client = ctx
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
|
||||
match app_state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
Some(NetworkMode::Limited),
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into());
|
||||
}
|
||||
|
||||
match app_state.host_blocked(&host, port).await {
|
||||
Ok((true, reason)) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("SOCKS allowed (client={client}, host={host}, port={port})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
|
||||
tcp_connector.serve(ctx, req).await
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
match app_state.host_blocked(&host, port).await {
|
||||
Ok((true, reason)) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into());
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("SOCKS allowed (client={client}, host={host}, port={port})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
tcp_connector.serve(ctx, req).await
|
||||
}
|
||||
});
|
||||
|
||||
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
|
||||
let socks_acceptor = Socks5Acceptor::new().with_connector(socks_connector);
|
||||
|
||||
listener.serve(socks_acceptor).await;
|
||||
listener
|
||||
.serve(AddExtensionLayer::new(state).into_layer(socks_acceptor))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ use codex_core::config::CONFIG_TOML_FILE;
|
||||
use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config::Constrained;
|
||||
use codex_core::config::ConstraintError;
|
||||
use codex_core::config_loader::RequirementSource;
|
||||
use globset::GlobBuilder;
|
||||
use globset::GlobSet;
|
||||
use globset::GlobSetBuilder;
|
||||
@@ -80,6 +81,14 @@ pub struct AppState {
|
||||
state: Arc<RwLock<ConfigState>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AppState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Avoid logging internal state (config contents, derived globsets, etc.) which can be noisy
|
||||
// and may contain sensitive paths.
|
||||
f.debug_struct("AppState").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new() -> Result<Self> {
|
||||
let cfg_state = build_config_state().await?;
|
||||
@@ -467,11 +476,25 @@ fn validate_policy_against_constraints(
|
||||
config: &Config,
|
||||
constraints: &NetworkProxyConstraints,
|
||||
) -> std::result::Result<(), ConstraintError> {
|
||||
fn invalid_value(
|
||||
field_name: &'static str,
|
||||
candidate: impl Into<String>,
|
||||
allowed: impl Into<String>,
|
||||
) -> ConstraintError {
|
||||
ConstraintError::InvalidValue {
|
||||
field_name,
|
||||
candidate: candidate.into(),
|
||||
allowed: allowed.into(),
|
||||
requirement_source: RequirementSource::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
let enabled = config.network_proxy.enabled;
|
||||
if let Some(max_enabled) = constraints.enabled {
|
||||
let _ = Constrained::new(enabled, move |candidate| {
|
||||
if *candidate && !max_enabled {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.enabled",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
@@ -484,7 +507,8 @@ fn validate_policy_against_constraints(
|
||||
if let Some(max_mode) = constraints.mode {
|
||||
let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
|
||||
if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.mode",
|
||||
format!("{candidate:?}"),
|
||||
format!("{max_mode:?} or more restrictive"),
|
||||
))
|
||||
@@ -501,7 +525,8 @@ fn validate_policy_against_constraints(
|
||||
Some(true) | None => Ok(()),
|
||||
Some(false) => {
|
||||
if *candidate {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.dangerously_allow_non_loopback_admin",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
@@ -519,7 +544,8 @@ fn validate_policy_against_constraints(
|
||||
Some(true) | None => Ok(()),
|
||||
Some(false) => {
|
||||
if *candidate {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.dangerously_allow_non_loopback_proxy",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
@@ -535,7 +561,8 @@ fn validate_policy_against_constraints(
|
||||
config.network_proxy.policy.allow_local_binding,
|
||||
move |candidate| {
|
||||
if *candidate && !allow_local_binding {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.allow_local_binding",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
@@ -563,7 +590,8 @@ fn validate_policy_against_constraints(
|
||||
if invalid.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.allowed_domains",
|
||||
format!("{invalid:?}"),
|
||||
"subset of managed allowed_domains",
|
||||
))
|
||||
@@ -590,7 +618,8 @@ fn validate_policy_against_constraints(
|
||||
if missing.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.denied_domains",
|
||||
"missing managed denied_domains entries",
|
||||
format!("{missing:?}"),
|
||||
))
|
||||
@@ -616,7 +645,8 @@ fn validate_policy_against_constraints(
|
||||
if invalid.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConstraintError::invalid_value(
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.allow_unix_sockets",
|
||||
format!("{invalid:?}"),
|
||||
"subset of managed allow_unix_sockets",
|
||||
))
|
||||
@@ -766,7 +796,9 @@ mod tests {
|
||||
(false, String::new())
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("not-example.com", 80).await.unwrap(),
|
||||
// Use a public IP literal to avoid relying on ambient DNS behavior (some networks
|
||||
// resolve unknown hostnames to private IPs, which would trigger `not_allowed_local`).
|
||||
state.host_blocked("8.8.8.8", 80).await.unwrap(),
|
||||
(true, "not_allowed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user