use rama instead of implementing our own proxy stack

This commit is contained in:
viyatb-oai
2025-12-21 21:49:51 -08:00
parent eceb76bf3d
commit 83e8a702fb
8 changed files with 1990 additions and 870 deletions

View File

@@ -1,446 +1,450 @@
use crate::config::NetworkMode;
use crate::mitm;
use crate::policy::normalize_host;
use crate::responses::blocked_text;
use crate::responses::json_blocked;
use crate::responses::text_response;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context;
use anyhow::Result;
use hyper::Body;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::Server;
use hyper::StatusCode;
use hyper::Uri;
use hyper::body::to_bytes;
use hyper::header::HOST;
use hyper::header::HeaderName;
use hyper::service::make_service_fn;
use hyper::service::service_fn;
use std::collections::HashSet;
use anyhow::anyhow;
use rama::Context as RamaContext;
use rama::Layer;
use rama::Service;
use rama::http::Body;
use rama::http::Request;
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::client::EasyHttpWebClient;
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
use rama::http::layer::upgrade::UpgradeLayer;
use rama::http::layer::upgrade::Upgraded;
use rama::http::matcher::MethodMatcher;
use rama::http::server::HttpServer;
use rama::net::http::RequestContext;
use rama::net::proxy::ProxyTarget;
use rama::net::stream::SocketInfo;
use rama::service::service_fn;
use rama::tcp::client::service::Forwarder;
use rama::tcp::server::TcpListener;
use serde_json::json;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::copy_bidirectional;
use tokio::net::TcpStream;
use tracing::error;
use tracing::info;
use tracing::warn;
type ContextState = Arc<AppState>;
type ProxyContext = RamaContext<ContextState>;
pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let make_svc = make_service_fn(move |conn: &hyper::server::conn::AddrStream| {
let state = state.clone();
let client_addr = conn.remote_addr();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
handle_proxy_request(req, state.clone(), client_addr)
}))
}
});
let server = Server::bind(&addr).serve(make_svc);
let listener = TcpListener::build_with_state(state)
.bind(addr)
.await
.map_err(|err| anyhow!("bind HTTP proxy: {err}"))?;
let http_service = HttpServer::auto(rama::rt::Executor::new()).service(
(
UpgradeLayer::new(
MethodMatcher::CONNECT,
service_fn(http_connect_accept),
service_fn(http_connect_proxy),
),
RemoveResponseHeaderLayer::hop_by_hop(),
RemoveRequestHeaderLayer::hop_by_hop(),
)
.into_layer(service_fn(http_plain_proxy)),
);
info!(addr = %addr, "HTTP proxy listening");
server.await?;
listener.serve(http_service).await;
Ok(())
}
async fn handle_proxy_request(
req: Request<Body>,
state: Arc<AppState>,
client_addr: SocketAddr,
) -> Result<Response<Body>, Infallible> {
let response = if req.method() == Method::CONNECT {
handle_connect(req, state, client_addr).await
} else {
handle_http_forward(req, state, client_addr).await
async fn http_connect_accept(
mut ctx: ProxyContext,
req: Request,
) -> Result<(Response, ProxyContext, Request), Response> {
let authority = match ctx
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
.map(|ctx| ctx.authority.clone())
{
Ok(authority) => authority,
Err(err) => {
warn!(error = %err, "CONNECT missing authority");
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
}
};
Ok(response)
}
async fn handle_connect(
req: Request<Body>,
state: Arc<AppState>,
client_addr: SocketAddr,
) -> Response<Body> {
let authority = match req.uri().authority() {
Some(auth) => auth.as_str().to_string(),
None => return text_response(StatusCode::BAD_REQUEST, "missing authority"),
};
let (authority_host, target_port) = split_authority(&authority);
let host = normalize_host(&authority_host);
let host = normalize_host(&authority.host().to_string());
if host.is_empty() {
return text_response(StatusCode::BAD_REQUEST, "invalid host");
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
}
match state.host_blocked(&host).await {
let app_state = ctx.state().clone();
let client = client_addr(&ctx);
match app_state.host_blocked(&host).await {
Ok((true, reason)) => {
let _ = state
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
Some(client_addr.to_string()),
client.clone(),
Some("CONNECT".to_string()),
None,
"http-connect".to_string(),
))
.await;
warn!(client = %client_addr, host = %host, reason = %reason, "CONNECT blocked");
return blocked_text(&reason);
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
reason = %reason,
"CONNECT blocked"
);
return Err(blocked_text(&reason));
}
Ok((false, _)) => {
info!(client = %client_addr, host = %host, "CONNECT allowed");
info!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
"CONNECT allowed"
);
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
let mode = match state.network_mode().await {
let mode = match app_state.network_mode().await {
Ok(mode) => mode,
Err(err) => {
error!(error = %err, "failed to read network mode");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
let mitm_state = match state.mitm_state().await {
let mitm_state = match app_state.mitm_state().await {
Ok(state) => state,
Err(err) => {
error!(error = %err, "failed to load MITM state");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
if mode == NetworkMode::Limited && mitm_state.is_none() {
let _ = state
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"mitm_required".to_string(),
Some(client_addr.to_string()),
client.clone(),
Some("CONNECT".to_string()),
Some(NetworkMode::Limited),
"http-connect".to_string(),
))
.await;
warn!(
client = %client_addr,
client = %client.as_deref().unwrap_or_default(),
host = %host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"CONNECT blocked; MITM required for read-only HTTPS in limited mode"
);
return blocked_text("mitm_required");
return Err(blocked_text("mitm_required"));
}
let on_upgrade = hyper::upgrade::on(req);
tokio::spawn(async move {
match on_upgrade.await {
Ok(upgraded) => {
if let Some(mitm_state) = mitm_state {
info!(client = %client_addr, host = %host, mode = ?mode, "CONNECT MITM enabled");
if let Err(err) =
mitm::mitm_tunnel(upgraded, &host, target_port, mode, mitm_state).await
{
warn!(error = %err, "MITM tunnel error");
}
return;
}
let mut upgraded = upgraded;
match TcpStream::connect(&authority).await {
Ok(mut server_stream) => {
if let Err(err) =
copy_bidirectional(&mut upgraded, &mut server_stream).await
{
warn!(error = %err, "tunnel error");
}
}
Err(err) => {
warn!(error = %err, "failed to connect to upstream");
}
}
}
Err(err) => warn!(error = %err, "upgrade failed"),
}
});
ctx.insert(ProxyTarget(authority));
ctx.insert(mode);
if let Some(mitm_state) = mitm_state {
ctx.insert(mitm_state);
}
Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
Ok((
Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty())),
ctx,
req,
))
}
async fn handle_http_forward(
req: Request<Body>,
state: Arc<AppState>,
client_addr: SocketAddr,
) -> Response<Body> {
let (parts, body) = req.into_parts();
let method_allowed = match state.method_allowed(&parts.method).await {
async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(), Infallible> {
let mode = ctx
.get::<NetworkMode>()
.copied()
.unwrap_or(NetworkMode::Full);
let authority = match ctx.get::<ProxyTarget>().map(|target| target.0.clone()) {
Some(authority) => authority,
None => {
warn!("CONNECT missing proxy target");
return Ok(());
}
};
let host = normalize_host(&authority.host().to_string());
if let Some(mitm_state) = ctx.get::<Arc<mitm::MitmState>>().cloned() {
info!(host = %host, port = authority.port(), mode = ?mode, "CONNECT MITM enabled");
if let Err(err) = mitm::mitm_tunnel(
ctx,
upgraded,
host.as_str(),
authority.port(),
mode,
mitm_state,
)
.await
{
warn!(error = %err, "MITM tunnel error");
}
return Ok(());
}
let forwarder = Forwarder::ctx();
if let Err(err) = forwarder.serve(ctx, upgraded).await {
warn!(error = %err, "tunnel error");
}
Ok(())
}
async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Response, Infallible> {
let app_state = ctx.state().clone();
let client = client_addr(&ctx);
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
Ok(allowed) => allowed,
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
let unix_socket = parts
.headers
if let Some(socket_path) = req
.headers()
.get("x-unix-socket")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
if let Some(socket_path) = unix_socket {
.map(|v| v.to_string())
{
if !method_allowed {
warn!(
client = %client_addr,
method = %parts.method,
client = %client.as_deref().unwrap_or_default(),
method = %req.method(),
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"unix socket blocked by method policy"
);
return json_blocked("unix-socket", "method_not_allowed");
return Ok(json_blocked("unix-socket", "method_not_allowed"));
}
if !cfg!(target_os = "macos") {
warn!(path = %socket_path, "unix socket proxy unsupported on this platform");
return text_response(StatusCode::NOT_IMPLEMENTED, "unix sockets unsupported");
return Ok(text_response(
StatusCode::NOT_IMPLEMENTED,
"unix sockets unsupported",
));
}
match state.is_unix_socket_allowed(&socket_path).await {
match app_state.is_unix_socket_allowed(&socket_path).await {
Ok(true) => {
info!(client = %client_addr, path = %socket_path, "unix socket allowed");
match proxy_via_unix_socket(Request::from_parts(parts, body), &socket_path).await {
Ok(resp) => return resp,
info!(
client = %client.as_deref().unwrap_or_default(),
path = %socket_path,
"unix socket allowed"
);
match proxy_via_unix_socket(ctx, req, &socket_path).await {
Ok(resp) => return Ok(resp),
Err(err) => {
warn!(error = %err, "unix socket proxy failed");
return text_response(StatusCode::BAD_GATEWAY, "unix socket proxy failed");
return Ok(text_response(
StatusCode::BAD_GATEWAY,
"unix socket proxy failed",
));
}
}
}
Ok(false) => {
warn!(client = %client_addr, path = %socket_path, "unix socket blocked");
return json_blocked("unix-socket", "not_allowed");
warn!(
client = %client.as_deref().unwrap_or_default(),
path = %socket_path,
"unix socket blocked"
);
return Ok(json_blocked("unix-socket", "not_allowed"));
}
Err(err) => {
warn!(error = %err, "unix socket check failed");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
}
let host_header = parts
.headers
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string())
.or_else(|| parts.uri.authority().map(|a| a.as_str().to_string()));
let authority = match host_header {
Some(h) => h,
None => return text_response(StatusCode::BAD_REQUEST, "missing host"),
let authority = match ctx
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
.map(|ctx| ctx.authority.clone())
{
Ok(authority) => authority,
Err(err) => {
warn!(error = %err, "missing host");
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
}
};
let authority = authority.trim().to_string();
let host = normalize_host(&authority);
if host.is_empty() {
return text_response(StatusCode::BAD_REQUEST, "invalid host");
}
let host = normalize_host(&authority.host().to_string());
match state.host_blocked(&host).await {
match app_state.host_blocked(&host).await {
Ok((true, reason)) => {
let _ = state
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
Some(client_addr.to_string()),
Some(parts.method.to_string()),
client.clone(),
Some(req.method().as_str().to_string()),
None,
"http".to_string(),
))
.await;
warn!(client = %client_addr, host = %host, reason = %reason, "request blocked");
return json_blocked(&host, &reason);
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
reason = %reason,
"request blocked"
);
return Ok(json_blocked(&host, &reason));
}
Ok((false, _)) => {}
Err(err) => {
error!(error = %err, "failed to evaluate host");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
if !method_allowed {
let _ = state
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
Some(client_addr.to_string()),
Some(parts.method.to_string()),
client.clone(),
Some(req.method().as_str().to_string()),
Some(NetworkMode::Limited),
"http".to_string(),
))
.await;
warn!(
client = %client_addr,
client = %client.as_deref().unwrap_or_default(),
host = %host,
method = %parts.method,
method = %req.method(),
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"request blocked by method policy"
);
return json_blocked(&host, "method_not_allowed");
return Ok(json_blocked(&host, "method_not_allowed"));
}
info!(
client = %client_addr,
client = %client.as_deref().unwrap_or_default(),
host = %host,
method = %parts.method,
method = %req.method(),
"request allowed"
);
let uri = match build_forward_uri(&authority, &parts.uri) {
Ok(uri) => uri,
Err(err) => {
warn!(error = %err, "failed to build upstream uri");
return text_response(StatusCode::BAD_REQUEST, "invalid uri");
}
};
let body_bytes = match to_bytes(body).await {
Ok(bytes) => bytes,
Err(err) => {
warn!(error = %err, "failed to read body");
return text_response(StatusCode::BAD_GATEWAY, "failed to read body");
}
};
let mut builder = Request::builder()
.method(parts.method)
.uri(uri)
.version(parts.version);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
let name_str = name.as_str().to_ascii_lowercase();
if hop_headers.contains(name_str.as_str())
|| name == &HeaderName::from_static("x-unix-socket")
{
continue;
}
builder = builder.header(name, value);
}
let forwarded_req = match builder.body(Body::from(body_bytes)) {
Ok(req) => req,
Err(err) => {
warn!(error = %err, "failed to build request");
return text_response(StatusCode::BAD_GATEWAY, "invalid request");
}
};
match state.client.request(forwarded_req).await {
Ok(resp) => filter_response(resp),
let client = EasyHttpWebClient::default();
match client.serve(ctx, req).await {
Ok(resp) => Ok(resp),
Err(err) => {
warn!(error = %err, "upstream request failed");
text_response(StatusCode::BAD_GATEWAY, "upstream failure")
Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
}
}
}
fn build_forward_uri(authority: &str, uri: &Uri) -> Result<Uri> {
let path = path_and_query(uri);
let target = format!("http://{authority}{path}");
Ok(target.parse()?)
}
fn filter_response(resp: Response<Body>) -> Response<Body> {
let mut builder = Response::builder().status(resp.status());
let hop_headers = hop_by_hop_headers();
for (name, value) in resp.headers().iter() {
if hop_headers.contains(name.as_str().to_ascii_lowercase().as_str()) {
continue;
}
builder = builder.header(name, value);
}
builder
.body(resp.into_body())
.unwrap_or_else(|_| Response::new(Body::from("proxy error")))
}
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/")
.to_string()
}
fn hop_by_hop_headers() -> HashSet<&'static str> {
[
"connection",
"proxy-connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
]
.into_iter()
.collect()
}
fn split_authority(authority: &str) -> (String, u16) {
if let Some(host) = authority.strip_prefix('[') {
if let Some(end) = host.find(']') {
let hostname = host[..end].to_string();
let port = host[end + 1..]
.strip_prefix(':')
.and_then(|p| p.parse::<u16>().ok())
.unwrap_or(443);
return (hostname, port);
}
}
let mut parts = authority.splitn(2, ':');
let host = parts.next().unwrap_or("").to_string();
let port = parts
.next()
.and_then(|p| p.parse::<u16>().ok())
.unwrap_or(443);
(host, port)
}
async fn proxy_via_unix_socket(req: Request<Body>, socket_path: &str) -> Result<Response<Body>> {
async fn proxy_via_unix_socket(
ctx: ProxyContext,
req: Request,
socket_path: &str,
) -> Result<Response> {
#[cfg(target_os = "macos")]
{
use hyper::client::conn::Builder as ConnBuilder;
use tokio::net::UnixStream;
use rama::unix::client::UnixConnector;
let path = path_and_query(req.uri());
let (parts, body) = req.into_parts();
let body_bytes = to_bytes(body).await?;
let mut builder = Request::builder()
.method(parts.method)
.uri(path)
.version(parts.version);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
let name_str = name.as_str().to_ascii_lowercase();
if hop_headers.contains(name_str.as_str())
|| name == &HeaderName::from_static("x-unix-socket")
{
continue;
}
builder = builder.header(name, value);
}
let req = builder.body(Body::from(body_bytes))?;
let stream = UnixStream::connect(socket_path).await?;
let (mut sender, conn) = ConnBuilder::new().handshake(stream).await?;
tokio::spawn(async move {
if let Err(err) = conn.await {
warn!(error = %err, "unix socket connection error");
}
});
Ok(sender.send_request(req).await?)
let client = EasyHttpWebClient::builder()
.with_custom_transport_connector(UnixConnector::fixed(socket_path))
.without_tls_proxy_support()
.without_proxy_support()
.without_tls_support()
.build();
let (mut parts, body) = req.into_parts();
let path = parts
.uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
parts.uri = path
.parse()
.with_context(|| format!("invalid unix socket request path: {path}"))?;
parts.headers.remove("x-unix-socket");
let req = Request::from_parts(parts, body);
Ok(client.serve(ctx, req).await?)
}
#[cfg(not(target_os = "macos"))]
{
let _ = ctx;
let _ = req;
let _ = socket_path;
Err(anyhow::anyhow!("unix sockets not supported"))
}
}
fn client_addr(ctx: &ProxyContext) -> Option<String> {
ctx.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string())
}
fn json_blocked(host: &str, reason: &str) -> Response {
let body = Body::from(json!({"status":"blocked","host":host,"reason":reason}).to_string());
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "application/json")
.header("x-proxy-error", blocked_header_value(reason))
.body(body)
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
fn blocked_text(reason: &str) -> Response {
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "text/plain")
.header("x-proxy-error", blocked_header_value(reason))
.body(Body::from(blocked_message(reason)))
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
fn text_response(status: StatusCode, body: &str) -> Response {
Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(Body::from(body.to_string()))
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
fn blocked_header_value(reason: &str) -> &'static str {
match reason {
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
"denied" => "blocked-by-denylist",
"method_not_allowed" => "blocked-by-method-policy",
"mitm_required" => "blocked-by-mitm-required",
_ => "blocked-by-policy",
}
}
fn blocked_message(reason: &str) -> &'static str {
match reason {
"not_allowed" => "Codex blocked this request: domain not in allowlist.",
"not_allowed_local" => "Codex blocked this request: local addresses not allowed.",
"denied" => "Codex blocked this request: domain denied by policy.",
"method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
"mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
_ => "Codex blocked this request by network policy.",
}
}

View File

@@ -4,21 +4,46 @@ mod imp {
use crate::config::NetworkMode;
use crate::policy::method_allowed;
use crate::policy::normalize_host;
use crate::responses::text_response;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use hyper::Body;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::StatusCode;
use hyper::Uri;
use hyper::Version;
use hyper::body::HttpBody;
use hyper::header::HOST;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use rama::Context as RamaContext;
use rama::Layer;
use rama::Service;
use rama::bytes::Bytes;
use rama::error::BoxError;
use rama::error::OpaqueError;
use rama::futures::stream::Stream;
use rama::http::Body;
use rama::http::HeaderValue;
use rama::http::Request;
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::Uri;
use rama::http::header::HOST;
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
use rama::http::layer::upgrade::Upgraded;
use rama::http::server::HttpServer;
use rama::net::proxy::ProxyTarget;
use rama::net::stream::SocketInfo;
use rama::service::service_fn;
use rama::tls::rustls::dep::pemfile;
use rama::tls::rustls::server::TlsAcceptorData;
use rama::tls::rustls::server::TlsAcceptorDataBuilder;
use rama::tls::rustls::server::TlsAcceptorLayer;
use std::fs;
use std::io::BufReader;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context as TaskContext;
use std::task::Poll;
use tracing::info;
use tracing::warn;
use rcgen::BasicConstraints;
use rcgen::Certificate;
use rcgen::CertificateParams;
@@ -29,62 +54,11 @@ mod imp {
use rcgen::KeyPair;
use rcgen::KeyUsagePurpose;
use rcgen::SanType;
use rustls::Certificate as RustlsCertificate;
use rustls::ClientConfig;
use rustls::PrivateKey;
use rustls::RootCertStore;
use rustls::ServerConfig;
use std::collections::HashSet;
use std::convert::Infallible;
use std::fs;
use std::io::Cursor;
use std::net::IpAddr;
use std::path::Path;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::TlsConnector;
use tracing::info;
use tracing::warn;
#[derive(Clone, Copy, Debug)]
enum MitmProtocol {
Http1,
Http2,
}
struct MitmTarget {
host: String,
port: u16,
}
impl MitmTarget {
fn authority(&self) -> String {
if self.port == 443 {
self.host.clone()
} else {
format!("{}:{}", self.host, self.port)
}
}
}
struct RequestLogContext {
host: String,
method: Method,
path: String,
}
struct ResponseLogContext {
host: String,
method: Method,
path: String,
status: StatusCode,
}
pub struct MitmState {
ca_key: KeyPair,
ca_cert: Certificate,
client_config: Arc<ClientConfig>,
upstream: rama::service::BoxService<Arc<AppState>, Request, Response, OpaqueError>,
inspect: bool,
max_body_bytes: usize,
}
@@ -98,30 +72,44 @@ mod imp {
let ca_cert = ca_params
.self_signed(&ca_key)
.context("failed to reconstruct CA cert")?;
let client_config = build_client_config()?;
let tls_config = rama::tls::rustls::client::TlsConnectorData::new_http_auto()
.context("create upstream TLS config")?;
let upstream = rama::http::client::EasyHttpWebClient::builder()
.with_default_transport_connector()
.without_tls_proxy_support()
.without_proxy_support()
.with_tls_support_using_rustls(Some(tls_config))
.build()
.boxed();
Ok(Self {
ca_key,
ca_cert,
client_config,
upstream,
inspect: cfg.inspect,
max_body_bytes: cfg.max_body_bytes,
})
}
pub fn server_config_for_host(&self, host: &str) -> Result<Arc<ServerConfig>> {
let (certs, key) = issue_host_certificate(host, &self.ca_cert, &self.ca_key)?;
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.context("failed to build server TLS config")?;
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(Arc::new(config))
}
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
let (cert_pem, key_pem) =
issue_host_certificate_pem(host, &self.ca_cert, &self.ca_key)?;
let cert_chain = pemfile::certs(&mut BufReader::new(cert_pem.as_bytes()))
.collect::<Result<Vec<_>, _>>()
.context("failed to parse host cert PEM")?;
if cert_chain.is_empty() {
return Err(anyhow!("no certificates found"));
}
pub fn client_config(&self) -> Arc<ClientConfig> {
Arc::clone(&self.client_config)
let key_der = pemfile::private_key(&mut BufReader::new(key_pem.as_bytes()))
.context("failed to parse host key PEM")?
.context("no private key found")?;
Ok(TlsAcceptorDataBuilder::new(cert_chain, key_der)
.context("failed to build rustls acceptor config")?
.with_alpn_protocols_http_auto()
.build())
}
pub fn inspect_enabled(&self) -> bool {
@@ -134,97 +122,87 @@ mod imp {
}
pub async fn mitm_tunnel(
stream: hyper::upgrade::Upgraded,
mut ctx: RamaContext<Arc<AppState>>,
upgraded: Upgraded,
host: &str,
port: u16,
_port: u16,
mode: NetworkMode,
state: Arc<MitmState>,
) -> Result<()> {
let server_config = state.server_config_for_host(host)?;
let acceptor = TlsAcceptor::from(server_config);
let tls_stream = acceptor
.accept(stream)
.await
.context("client TLS handshake failed")?;
let protocol = match tls_stream.get_ref().1.alpn_protocol() {
Some(proto) if proto == b"h2" => MitmProtocol::Http2,
_ => MitmProtocol::Http1,
};
info!(
host = %host,
port = port,
protocol = ?protocol,
mode = ?mode,
inspect = state.inspect_enabled(),
max_body_bytes = state.max_body_bytes(),
"MITM TLS established"
// Ensure the MITM state is available for the per-request handler.
ctx.insert(state.clone());
ctx.insert(mode);
let acceptor_data = state.tls_acceptor_data_for_host(host)?;
let http_service = HttpServer::auto(ctx.executor().clone()).service(
(
RemoveResponseHeaderLayer::hop_by_hop(),
RemoveRequestHeaderLayer::hop_by_hop(),
)
.into_layer(service_fn(handle_mitm_request)),
);
let target = Arc::new(MitmTarget {
host: host.to_string(),
port,
});
let service = {
let state = state.clone();
let target = target.clone();
service_fn(move |req| handle_mitm_request(req, target.clone(), mode, state.clone()))
};
let https_service = TlsAcceptorLayer::new(acceptor_data)
.with_store_client_hello(true)
.into_layer(http_service);
let mut http = Http::new();
match protocol {
MitmProtocol::Http2 => {
http.http2_only(true);
}
MitmProtocol::Http1 => {
http.http1_only(true);
}
}
http.serve_connection(tls_stream, service)
https_service
.serve(ctx, upgraded)
.await
.context("MITM HTTP handling failed")?;
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
Ok(())
}
async fn handle_mitm_request(
req: Request<Body>,
target: Arc<MitmTarget>,
mode: NetworkMode,
state: Arc<MitmState>,
) -> Result<Response<Body>, Infallible> {
let response = match forward_request(req, target.as_ref(), mode, state.as_ref()).await {
ctx: RamaContext<Arc<AppState>>,
req: Request,
) -> Result<Response, std::convert::Infallible> {
let response = match forward_request(ctx, req).await {
Ok(resp) => resp,
Err(err) => {
warn!(error = %err, host = %target.host, "MITM upstream request failed");
warn!(error = %err, "MITM upstream request failed");
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
}
};
Ok(response)
}
async fn forward_request(
req: Request<Body>,
target: &MitmTarget,
mode: NetworkMode,
state: &MitmState,
) -> Result<Response<Body>> {
if req.method() == Method::CONNECT {
async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Result<Response> {
let target = ctx
.get::<ProxyTarget>()
.context("missing proxy target")?
.0
.clone();
let target_host = normalize_host(&target.host().to_string());
let target_port = target.port();
let mode = ctx
.get::<NetworkMode>()
.copied()
.unwrap_or(NetworkMode::Full);
let mitm = ctx
.get::<Arc<MitmState>>()
.cloned()
.context("missing MITM state")?;
if req.method().as_str() == "CONNECT" {
return Ok(text_response(
StatusCode::METHOD_NOT_ALLOWED,
"CONNECT not supported inside MITM",
));
}
let (parts, body) = req.into_parts();
let request_version = parts.version;
let method = parts.method.clone();
let inspect = state.inspect_enabled();
let max_body_bytes = state.max_body_bytes();
let method = req.method().as_str().to_string();
let path = path_and_query(req.uri());
let client = ctx
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
if let Some(request_host) = extract_request_host(&parts) {
if let Some(request_host) = extract_request_host(&req) {
let normalized = normalize_host(&request_host);
if !normalized.is_empty() && normalized != target.host {
if !normalized.is_empty() && normalized != target_host {
warn!(
target = %target.host,
target = %target_host,
request_host = %normalized,
"MITM host mismatch"
);
@@ -232,175 +210,143 @@ mod imp {
}
}
let path = path_and_query(&parts.uri);
let uri = build_origin_form_uri(&path)?;
let authority = target.authority();
if !method_allowed(mode, &method) {
if !method_allowed(mode, method.as_str()) {
let _ = ctx
.state()
.record_blocked(BlockedRequest::new(
target_host.clone(),
"method_not_allowed".to_string(),
client.clone(),
Some(method.clone()),
Some(NetworkMode::Limited),
"https".to_string(),
))
.await;
warn!(
host = %authority,
host = %target_host,
method = %method,
path = %path,
mode = ?mode,
allowed_methods = "GET, HEAD, OPTIONS",
"MITM blocked by method policy"
);
return Ok(text_response(StatusCode::FORBIDDEN, "method not allowed"));
return Ok(blocked_text("method_not_allowed"));
}
let mut builder = Request::builder()
.method(method.clone())
.uri(uri)
.version(Version::HTTP_11);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
let name_str = name.as_str().to_ascii_lowercase();
if hop_headers.contains(name_str.as_str()) || name == &HOST {
continue;
}
builder = builder.header(name, value);
}
builder = builder.header(HOST, authority.as_str());
let (mut parts, body) = req.into_parts();
let authority = authority_header_value(&target_host, target_port);
parts.uri = build_https_uri(&authority, &path)?;
parts
.headers
.insert(HOST, HeaderValue::from_str(&authority)?);
let inspect = mitm.inspect_enabled();
let max_body_bytes = mitm.max_body_bytes();
let body = if inspect {
let (tx, out_body) = Body::channel();
let ctx = RequestLogContext {
host: authority.clone(),
method: method.clone(),
path: path.clone(),
};
tokio::spawn(async move {
stream_body(body, tx, max_body_bytes, ctx).await;
});
out_body
inspect_body(
body,
max_body_bytes,
RequestLogContext {
host: authority.clone(),
method: method.clone(),
path: path.clone(),
},
)
} else {
body
};
let upstream_req = builder
.body(body)
.context("failed to build upstream request")?;
let upstream_resp = send_upstream_request(upstream_req, target, state).await?;
let upstream_req = Request::from_parts(parts, body);
let upstream_resp = mitm.upstream.serve(ctx, upstream_req).await?;
respond_with_inspection(
upstream_resp,
request_version,
inspect,
max_body_bytes,
&method,
&path,
&authority,
)
.await
}
async fn send_upstream_request(
req: Request<Body>,
target: &MitmTarget,
state: &MitmState,
) -> Result<Response<Body>> {
let upstream = TcpStream::connect((target.host.as_str(), target.port))
.await
.context("failed to connect to upstream")?;
let server_name = match target.host.parse::<IpAddr>() {
Ok(ip) => rustls::ServerName::IpAddress(ip),
Err(_) => rustls::ServerName::try_from(target.host.as_str())
.map_err(|_| anyhow!("invalid server name"))?,
};
let connector = TlsConnector::from(state.client_config());
let tls_stream = connector
.connect(server_name, upstream)
.await
.context("upstream TLS handshake failed")?;
let (mut sender, conn) = hyper::client::conn::Builder::new()
.handshake(tls_stream)
.await
.context("upstream HTTP handshake failed")?;
tokio::spawn(async move {
if let Err(err) = conn.await {
warn!(error = %err, "MITM upstream connection error");
}
});
let resp = sender
.send_request(req)
.await
.context("upstream request failed")?;
Ok(resp)
}
async fn respond_with_inspection(
resp: Response<Body>,
request_version: Version,
fn respond_with_inspection(
resp: Response,
inspect: bool,
max_body_bytes: usize,
method: &Method,
method: &str,
path: &str,
authority: &str,
) -> Result<Response<Body>> {
let (parts, body) = resp.into_parts();
let mut builder = Response::builder()
.status(parts.status)
.version(request_version);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
if hop_headers.contains(name.as_str().to_ascii_lowercase().as_str()) {
continue;
}
builder = builder.header(name, value);
) -> Result<Response> {
if !inspect {
return Ok(resp);
}
let body = if inspect {
let (tx, out_body) = Body::channel();
let ctx = ResponseLogContext {
let (parts, body) = resp.into_parts();
let body = inspect_body(
body,
max_body_bytes,
ResponseLogContext {
host: authority.to_string(),
method: method.clone(),
method: method.to_string(),
path: path.to_string(),
status: parts.status,
};
tokio::spawn(async move {
stream_body(body, tx, max_body_bytes, ctx).await;
});
out_body
} else {
body
};
Ok(builder
.body(body)
.unwrap_or_else(|_| Response::new(Body::from("proxy error"))))
},
);
Ok(Response::from_parts(parts, body))
}
async fn stream_body<T>(
mut body: Body,
mut tx: hyper::body::Sender,
fn inspect_body<T: BodyLoggable + Send + 'static>(
body: Body,
max_body_bytes: usize,
ctx: T,
) where
T: BodyLoggable,
{
let mut len: usize = 0;
let mut truncated = false;
while let Some(chunk) = body.data().await {
match chunk {
Ok(bytes) => {
len = len.saturating_add(bytes.len());
if len > max_body_bytes {
truncated = true;
}
if tx.send_data(bytes).await.is_err() {
break;
}
) -> Body {
Body::from_stream(InspectStream {
inner: Box::pin(body.into_data_stream()),
ctx: Some(Box::new(ctx)),
len: 0,
max_body_bytes,
})
}
struct InspectStream<T> {
inner: Pin<Box<rama::http::BodyDataStream>>,
ctx: Option<Box<T>>,
len: usize,
max_body_bytes: usize,
}
impl<T: BodyLoggable> Stream for InspectStream<T> {
type Item = Result<Bytes, BoxError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match this.inner.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.len = this.len.saturating_add(bytes.len());
Poll::Ready(Some(Ok(bytes)))
}
Err(err) => {
warn!(error = %err, "MITM body stream error");
break;
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(None) => {
if let Some(ctx) = this.ctx.take() {
ctx.log(this.len, this.len > this.max_body_bytes);
}
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
if let Ok(Some(trailers)) = body.trailers().await {
let _ = tx.send_trailers(trailers).await;
}
ctx.log(len, truncated);
}
struct RequestLogContext {
host: String,
method: String,
path: String,
}
struct ResponseLogContext {
host: String,
method: String,
path: String,
status: StatusCode,
}
trait BodyLoggable {
@@ -434,13 +380,32 @@ mod imp {
}
}
fn extract_request_host(parts: &hyper::http::request::Parts) -> Option<String> {
parts
.headers
fn extract_request_host(req: &Request) -> Option<String> {
req.headers()
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string())
.or_else(|| parts.uri.authority().map(|a| a.as_str().to_string()))
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
}
fn authority_header_value(host: &str, port: u16) -> String {
// Host header / URI authority formatting.
if host.contains(':') {
if port == 443 {
format!("[{host}]")
} else {
format!("[{host}]:{port}")
}
} else if port == 443 {
host.to_string()
} else {
format!("{host}:{port}")
}
}
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
let target = format!("https://{authority}{path}");
Ok(target.parse()?)
}
fn path_and_query(uri: &Uri) -> String {
@@ -450,51 +415,11 @@ mod imp {
.to_string()
}
fn build_origin_form_uri(path: &str) -> Result<Uri> {
path.parse().context("invalid request path")
}
fn hop_by_hop_headers() -> HashSet<&'static str> {
[
"connection",
"proxy-connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
]
.into_iter()
.collect()
}
fn build_client_config() -> Result<Arc<ClientConfig>> {
let mut roots = RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs()
.map_err(|err| anyhow!("failed to load native certs: {err}"))?;
for cert in certs {
if roots.add(&RustlsCertificate(cert.0)).is_err() {
warn!("skipping invalid root cert");
}
}
if roots.is_empty() {
return Err(anyhow!("no root certificates available"));
}
let mut config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(Arc::new(config))
}
fn issue_host_certificate(
fn issue_host_certificate_pem(
host: &str,
ca_cert: &Certificate,
ca_key: &KeyPair,
) -> Result<(Vec<RustlsCertificate>, PrivateKey)> {
) -> Result<(String, String)> {
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
let mut params = CertificateParams::new(Vec::new())
.map_err(|err| anyhow!("failed to create cert params: {err}"))?;
@@ -504,6 +429,7 @@ mod imp {
CertificateParams::new(vec![host.to_string()])
.map_err(|err| anyhow!("failed to create cert params: {err}"))?
};
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
@@ -516,16 +442,13 @@ mod imp {
.signed_by(&key_pair, ca_cert, ca_key)
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
let cert_pem = cert.pem();
let key_pem = key_pair.serialize_pem();
let certs = certs_from_pem(&cert_pem)?;
let key = private_key_from_pem(&key_pem)?;
Ok((certs, key))
Ok((cert.pem(), key_pair.serialize_pem()))
}
fn load_or_create_ca(cfg: &MitmConfig) -> Result<(String, String)> {
let cert_path = &cfg.ca_cert_path;
let key_path = &cfg.ca_key_path;
if cert_path.exists() || key_path.exists() {
if !cert_path.exists() || !key_path.exists() {
return Err(anyhow!("both ca_cert_path and ca_key_path must exist"));
@@ -574,63 +497,75 @@ mod imp {
let cert = params
.self_signed(&key_pair)
.map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
let cert_pem = cert.pem();
let key_pem = key_pair.serialize_pem();
Ok((cert_pem, key_pem))
Ok((cert.pem(), key_pair.serialize_pem()))
}
fn certs_from_pem(pem: &str) -> Result<Vec<RustlsCertificate>> {
let mut reader = Cursor::new(pem);
let certs = rustls_pemfile::certs(&mut reader).context("failed to parse cert PEM")?;
if certs.is_empty() {
return Err(anyhow!("no certificates found"));
}
Ok(certs.into_iter().map(RustlsCertificate).collect())
}
fn private_key_from_pem(pem: &str) -> Result<PrivateKey> {
let mut reader = Cursor::new(pem);
let mut keys =
rustls_pemfile::pkcs8_private_keys(&mut reader).context("failed to parse pkcs8 key")?;
if let Some(key) = keys.pop() {
return Ok(PrivateKey(key));
}
let mut reader = Cursor::new(pem);
let mut keys =
rustls_pemfile::rsa_private_keys(&mut reader).context("failed to parse rsa key")?;
if let Some(key) = keys.pop() {
return Ok(PrivateKey(key));
}
Err(anyhow!("no private key found"))
}
fn write_private_file(path: &Path, contents: &[u8], mode: u32) -> Result<()> {
fn write_private_file(path: &std::path::Path, contents: &[u8], mode: u32) -> Result<()> {
fs::write(path, contents).with_context(|| format!("failed to write {}", path.display()))?;
set_permissions(path, mode)?;
Ok(())
}
#[cfg(unix)]
fn set_permissions(path: &Path, mode: u32) -> Result<()> {
fn set_permissions(path: &std::path::Path, mode: u32) -> Result<()> {
use std::os::unix::fs::PermissionsExt;
fs::set_permissions(path, fs::Permissions::from_mode(mode))
.with_context(|| format!("failed to set permissions on {}", path.display()))?;
Ok(())
}
#[cfg(not(unix))]
fn set_permissions(_path: &Path, _mode: u32) -> Result<()> {
fn set_permissions(_path: &std::path::Path, _mode: u32) -> Result<()> {
Ok(())
}
fn blocked_text(reason: &str) -> Response {
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "text/plain")
.header("x-proxy-error", blocked_header_value(reason))
.body(Body::from(blocked_message(reason)))
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
fn text_response(status: StatusCode, body: &str) -> Response {
Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(Body::from(body.to_string()))
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
fn blocked_header_value(reason: &str) -> &'static str {
match reason {
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
"denied" => "blocked-by-denylist",
"method_not_allowed" => "blocked-by-method-policy",
"mitm_required" => "blocked-by-mitm-required",
_ => "blocked-by-policy",
}
}
fn blocked_message(reason: &str) -> &'static str {
match reason {
"method_not_allowed" => {
"Codex blocked this request: method not allowed in limited mode."
}
_ => "Codex blocked this request by network policy.",
}
}
}
#[cfg(not(feature = "mitm"))]
mod imp {
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use crate::state::AppState;
use anyhow::Result;
use anyhow::anyhow;
use hyper::upgrade::Upgraded;
use rama::Context as RamaContext;
use rama::http::layer::upgrade::Upgraded;
use std::sync::Arc;
#[derive(Debug)]
@@ -652,7 +587,8 @@ mod imp {
}
pub async fn mitm_tunnel(
_stream: Upgraded,
_ctx: RamaContext<Arc<AppState>>,
_upgraded: Upgraded,
_host: &str,
_port: u16,
_mode: NetworkMode,

View File

@@ -1,11 +1,10 @@
use crate::config::NetworkMode;
use hyper::Method;
use std::net::IpAddr;
pub fn method_allowed(mode: NetworkMode, method: &Method) -> bool {
pub fn method_allowed(mode: NetworkMode, method: &str) -> bool {
match mode {
NetworkMode::Full => true,
NetworkMode::Limited => matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS),
NetworkMode::Limited => matches!(method, "GET" | "HEAD" | "OPTIONS"),
}
}

View File

@@ -2,26 +2,6 @@ use hyper::Body;
use hyper::Response;
use hyper::StatusCode;
use serde::Serialize;
use serde_json::json;
pub fn json_blocked(host: &str, reason: &str) -> Response<Body> {
let body = Body::from(json!({"status":"blocked","host":host,"reason":reason}).to_string());
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "application/json")
.header("x-proxy-error", blocked_header_value(reason))
.body(body)
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
pub fn blocked_text(reason: &str) -> Response<Body> {
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "text/plain")
.header("x-proxy-error", blocked_header_value(reason))
.body(Body::from(blocked_message(reason).to_string()))
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
Response::builder()
@@ -42,24 +22,3 @@ pub fn json_response<T: Serialize>(value: &T) -> Response<Body> {
.body(Body::from(body))
.unwrap_or_else(|_| Response::new(Body::from("{}")))
}
fn blocked_header_value(reason: &str) -> &'static str {
match reason {
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
"denied" => "blocked-by-denylist",
"method_not_allowed" => "blocked-by-method-policy",
"mitm_required" => "blocked-by-mitm-required",
_ => "blocked-by-policy",
}
}
fn blocked_message(reason: &str) -> &'static str {
match reason {
"not_allowed" => "Codex blocked this request: domain not in allowlist.",
"not_allowed_local" => "Codex blocked this request: local addresses not allowed.",
"denied" => "Codex blocked this request: domain denied by policy.",
"method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
"mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
_ => "Codex blocked this request by network policy.",
}
}

View File

@@ -4,20 +4,30 @@ use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Result;
use anyhow::anyhow;
use rama::Context as RamaContext;
use rama::Service;
use rama::net::stream::SocketInfo;
use rama::proxy::socks5::Socks5Acceptor;
use rama::proxy::socks5::server::DefaultConnector;
use rama::service::service_fn;
use rama::tcp::client::Request as TcpRequest;
use rama::tcp::client::service::TcpConnector;
use rama::tcp::server::TcpListener;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::copy_bidirectional;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tracing::error;
use tracing::info;
use tracing::warn;
pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::bind(addr).await?;
let listener = TcpListener::build_with_state(state.clone())
.bind(addr)
.await
.map_err(|err| anyhow!("bind SOCKS5 proxy: {err}"))?;
info!(addr = %addr, "SOCKS5 proxy listening");
match state.network_mode().await {
Ok(NetworkMode::Limited) => {
info!(
@@ -30,163 +40,93 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
warn!(error = %err, "failed to read network mode");
}
}
loop {
let (stream, peer_addr) = listener.accept().await?;
let state = state.clone();
tokio::spawn(async move {
if let Err(err) = handle_socks5_client(stream, peer_addr, state).await {
warn!(error = %err, "SOCKS5 session ended with error");
let tcp_connector = TcpConnector::default();
let policy_tcp_connector =
service_fn(move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
let tcp_connector = tcp_connector.clone();
async move {
let app_state = ctx.state().clone();
let authority = req.authority().clone();
let host = normalize_host(&authority.host().to_string());
let port = authority.port();
let client = ctx
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
match app_state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
client.clone(),
None,
Some(NetworkMode::Limited),
"socks5".to_string(),
))
.await;
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"SOCKS blocked by method policy"
);
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
}
}
match app_state.host_blocked(&host).await {
Ok((true, reason)) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
None,
None,
"socks5".to_string(),
))
.await;
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
reason = %reason,
"SOCKS blocked"
);
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok((false, _)) => {
info!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
port = port,
"SOCKS allowed"
);
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
}
}
tcp_connector.serve(ctx, req).await
}
});
}
}
async fn handle_socks5_client(
mut stream: TcpStream,
peer_addr: SocketAddr,
state: Arc<AppState>,
) -> Result<()> {
let mut header = [0u8; 2];
stream.read_exact(&mut header).await?;
if header[0] != 0x05 {
return Err(anyhow!("invalid SOCKS version"));
}
let nmethods = header[1] as usize;
let mut methods = vec![0u8; nmethods];
stream.read_exact(&mut methods).await?;
stream.write_all(&[0x05, 0x00]).await?;
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
let socks_acceptor = Socks5Acceptor::new().with_connector(socks_connector);
let mut req_header = [0u8; 4];
stream.read_exact(&mut req_header).await?;
if req_header[0] != 0x05 {
return Err(anyhow!("invalid SOCKS request version"));
}
let cmd = req_header[1];
if cmd != 0x01 {
stream
.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Err(anyhow!("unsupported SOCKS command"));
}
let atyp = req_header[3];
let host = match atyp {
0x01 => {
let mut addr = [0u8; 4];
stream.read_exact(&mut addr).await?;
format!("{}.{}.{}.{}", addr[0], addr[1], addr[2], addr[3])
}
0x03 => {
let mut len_buf = [0u8; 1];
stream.read_exact(&mut len_buf).await?;
let len = len_buf[0] as usize;
let mut domain = vec![0u8; len];
stream.read_exact(&mut domain).await?;
String::from_utf8_lossy(&domain).to_string()
}
0x04 => {
stream
.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Err(anyhow!("ipv6 not supported"));
}
_ => {
stream
.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Err(anyhow!("unknown address type"));
}
};
let mut port_buf = [0u8; 2];
stream.read_exact(&mut port_buf).await?;
let port = u16::from_be_bytes(port_buf);
let normalized_host = normalize_host(&host);
match state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = state
.record_blocked(BlockedRequest::new(
normalized_host.clone(),
"method_not_allowed".to_string(),
Some(peer_addr.to_string()),
None,
Some(NetworkMode::Limited),
"socks5".to_string(),
))
.await;
warn!(
client = %peer_addr,
host = %normalized_host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"SOCKS blocked by method policy"
);
stream
.write_all(&[0x05, 0x02, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
stream
.write_all(&[0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
}
match state.host_blocked(&normalized_host).await {
Ok((true, reason)) => {
let _ = state
.record_blocked(BlockedRequest::new(
normalized_host.clone(),
reason.clone(),
Some(peer_addr.to_string()),
None,
None,
"socks5".to_string(),
))
.await;
warn!(client = %peer_addr, host = %normalized_host, reason = %reason, "SOCKS blocked");
stream
.write_all(&[0x05, 0x02, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
Ok((false, _)) => {
info!(
client = %peer_addr,
host = %normalized_host,
port = port,
"SOCKS allowed"
);
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
stream
.write_all(&[0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
}
let target = format!("{host}:{port}");
let mut upstream = match TcpStream::connect(&target).await {
Ok(stream) => stream,
Err(err) => {
warn!(error = %err, "SOCKS connect failed");
stream
.write_all(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
};
stream
.write_all(&[0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
let _ = copy_bidirectional(&mut stream, &mut upstream).await;
listener.serve(socks_acceptor).await;
Ok(())
}

View File

@@ -10,9 +10,6 @@ use anyhow::anyhow;
use globset::GlobBuilder;
use globset::GlobSet;
use globset::GlobSetBuilder;
use hyper::Client;
use hyper::Method;
use hyper::client::HttpConnector;
use serde::Serialize;
use std::collections::HashSet;
use std::collections::VecDeque;
@@ -72,16 +69,13 @@ struct ConfigState {
#[derive(Clone)]
pub struct AppState {
pub(crate) client: Client<HttpConnector>,
state: Arc<RwLock<ConfigState>>,
}
impl AppState {
pub async fn new(cfg_path: PathBuf) -> Result<Self> {
let cfg_state = build_config_state(cfg_path)?;
let client = Client::new();
Ok(Self {
client,
state: Arc::new(RwLock::new(cfg_state)),
})
}
@@ -171,7 +165,7 @@ impl AppState {
.any(|p| p == path))
}
pub async fn method_allowed(&self, method: &Method) -> Result<bool> {
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(method_allowed(guard.cfg.network_proxy.mode, method))