mirror of
https://github.com/openai/codex.git
synced 2026-04-30 01:16:54 +00:00
use rama instead of implementing our own proxy stack
This commit is contained in:
@@ -1,446 +1,450 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::mitm;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::responses::blocked_text;
|
||||
use crate::responses::json_blocked;
|
||||
use crate::responses::text_response;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use hyper::Body;
|
||||
use hyper::Method;
|
||||
use hyper::Request;
|
||||
use hyper::Response;
|
||||
use hyper::Server;
|
||||
use hyper::StatusCode;
|
||||
use hyper::Uri;
|
||||
use hyper::body::to_bytes;
|
||||
use hyper::header::HOST;
|
||||
use hyper::header::HeaderName;
|
||||
use hyper::service::make_service_fn;
|
||||
use hyper::service::service_fn;
|
||||
use std::collections::HashSet;
|
||||
use anyhow::anyhow;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Layer;
|
||||
use rama::Service;
|
||||
use rama::http::Body;
|
||||
use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::client::EasyHttpWebClient;
|
||||
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama::http::layer::upgrade::UpgradeLayer;
|
||||
use rama::http::layer::upgrade::Upgraded;
|
||||
use rama::http::matcher::MethodMatcher;
|
||||
use rama::http::server::HttpServer;
|
||||
use rama::net::http::RequestContext;
|
||||
use rama::net::proxy::ProxyTarget;
|
||||
use rama::net::stream::SocketInfo;
|
||||
use rama::service::service_fn;
|
||||
use rama::tcp::client::service::Forwarder;
|
||||
use rama::tcp::server::TcpListener;
|
||||
use serde_json::json;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::copy_bidirectional;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
type ContextState = Arc<AppState>;
|
||||
type ProxyContext = RamaContext<ContextState>;
|
||||
|
||||
pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
let make_svc = make_service_fn(move |conn: &hyper::server::conn::AddrStream| {
|
||||
let state = state.clone();
|
||||
let client_addr = conn.remote_addr();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
handle_proxy_request(req, state.clone(), client_addr)
|
||||
}))
|
||||
}
|
||||
});
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
let listener = TcpListener::build_with_state(state)
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow!("bind HTTP proxy: {err}"))?;
|
||||
|
||||
let http_service = HttpServer::auto(rama::rt::Executor::new()).service(
|
||||
(
|
||||
UpgradeLayer::new(
|
||||
MethodMatcher::CONNECT,
|
||||
service_fn(http_connect_accept),
|
||||
service_fn(http_connect_proxy),
|
||||
),
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
)
|
||||
.into_layer(service_fn(http_plain_proxy)),
|
||||
);
|
||||
|
||||
info!(addr = %addr, "HTTP proxy listening");
|
||||
server.await?;
|
||||
|
||||
listener.serve(http_service).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_proxy_request(
|
||||
req: Request<Body>,
|
||||
state: Arc<AppState>,
|
||||
client_addr: SocketAddr,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
let response = if req.method() == Method::CONNECT {
|
||||
handle_connect(req, state, client_addr).await
|
||||
} else {
|
||||
handle_http_forward(req, state, client_addr).await
|
||||
async fn http_connect_accept(
|
||||
mut ctx: ProxyContext,
|
||||
req: Request,
|
||||
) -> Result<(Response, ProxyContext, Request), Response> {
|
||||
let authority = match ctx
|
||||
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
|
||||
.map(|ctx| ctx.authority.clone())
|
||||
{
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "CONNECT missing authority");
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
|
||||
}
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
req: Request<Body>,
|
||||
state: Arc<AppState>,
|
||||
client_addr: SocketAddr,
|
||||
) -> Response<Body> {
|
||||
let authority = match req.uri().authority() {
|
||||
Some(auth) => auth.as_str().to_string(),
|
||||
None => return text_response(StatusCode::BAD_REQUEST, "missing authority"),
|
||||
};
|
||||
let (authority_host, target_port) = split_authority(&authority);
|
||||
let host = normalize_host(&authority_host);
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
if host.is_empty() {
|
||||
return text_response(StatusCode::BAD_REQUEST, "invalid host");
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
|
||||
}
|
||||
|
||||
match state.host_blocked(&host).await {
|
||||
let app_state = ctx.state().clone();
|
||||
let client = client_addr(&ctx);
|
||||
|
||||
match app_state.host_blocked(&host).await {
|
||||
Ok((true, reason)) => {
|
||||
let _ = state
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
Some(client_addr.to_string()),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
None,
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(client = %client_addr, host = %host, reason = %reason, "CONNECT blocked");
|
||||
return blocked_text(&reason);
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
reason = %reason,
|
||||
"CONNECT blocked"
|
||||
);
|
||||
return Err(blocked_text(&reason));
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
info!(client = %client_addr, host = %host, "CONNECT allowed");
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
"CONNECT allowed"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
|
||||
let mode = match state.network_mode().await {
|
||||
let mode = match app_state.network_mode().await {
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to read network mode");
|
||||
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
|
||||
let mitm_state = match state.mitm_state().await {
|
||||
let mitm_state = match app_state.mitm_state().await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to load MITM state");
|
||||
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
|
||||
if mode == NetworkMode::Limited && mitm_state.is_none() {
|
||||
let _ = state
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"mitm_required".to_string(),
|
||||
Some(client_addr.to_string()),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
Some(NetworkMode::Limited),
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client_addr,
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"CONNECT blocked; MITM required for read-only HTTPS in limited mode"
|
||||
);
|
||||
return blocked_text("mitm_required");
|
||||
return Err(blocked_text("mitm_required"));
|
||||
}
|
||||
|
||||
let on_upgrade = hyper::upgrade::on(req);
|
||||
tokio::spawn(async move {
|
||||
match on_upgrade.await {
|
||||
Ok(upgraded) => {
|
||||
if let Some(mitm_state) = mitm_state {
|
||||
info!(client = %client_addr, host = %host, mode = ?mode, "CONNECT MITM enabled");
|
||||
if let Err(err) =
|
||||
mitm::mitm_tunnel(upgraded, &host, target_port, mode, mitm_state).await
|
||||
{
|
||||
warn!(error = %err, "MITM tunnel error");
|
||||
}
|
||||
return;
|
||||
}
|
||||
let mut upgraded = upgraded;
|
||||
match TcpStream::connect(&authority).await {
|
||||
Ok(mut server_stream) => {
|
||||
if let Err(err) =
|
||||
copy_bidirectional(&mut upgraded, &mut server_stream).await
|
||||
{
|
||||
warn!(error = %err, "tunnel error");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to connect to upstream");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => warn!(error = %err, "upgrade failed"),
|
||||
}
|
||||
});
|
||||
ctx.insert(ProxyTarget(authority));
|
||||
ctx.insert(mode);
|
||||
if let Some(mitm_state) = mitm_state {
|
||||
ctx.insert(mitm_state);
|
||||
}
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::empty())
|
||||
.unwrap_or_else(|_| Response::new(Body::empty()))
|
||||
Ok((
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::empty())
|
||||
.unwrap_or_else(|_| Response::new(Body::empty())),
|
||||
ctx,
|
||||
req,
|
||||
))
|
||||
}
|
||||
|
||||
async fn handle_http_forward(
|
||||
req: Request<Body>,
|
||||
state: Arc<AppState>,
|
||||
client_addr: SocketAddr,
|
||||
) -> Response<Body> {
|
||||
let (parts, body) = req.into_parts();
|
||||
let method_allowed = match state.method_allowed(&parts.method).await {
|
||||
async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(), Infallible> {
|
||||
let mode = ctx
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
let authority = match ctx.get::<ProxyTarget>().map(|target| target.0.clone()) {
|
||||
Some(authority) => authority,
|
||||
None => {
|
||||
warn!("CONNECT missing proxy target");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
|
||||
if let Some(mitm_state) = ctx.get::<Arc<mitm::MitmState>>().cloned() {
|
||||
info!(host = %host, port = authority.port(), mode = ?mode, "CONNECT MITM enabled");
|
||||
if let Err(err) = mitm::mitm_tunnel(
|
||||
ctx,
|
||||
upgraded,
|
||||
host.as_str(),
|
||||
authority.port(),
|
||||
mode,
|
||||
mitm_state,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(error = %err, "MITM tunnel error");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let forwarder = Forwarder::ctx();
|
||||
if let Err(err) = forwarder.serve(ctx, upgraded).await {
|
||||
warn!(error = %err, "tunnel error");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Response, Infallible> {
|
||||
let app_state = ctx.state().clone();
|
||||
let client = client_addr(&ctx);
|
||||
|
||||
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
|
||||
Ok(allowed) => allowed,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate method policy");
|
||||
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
let unix_socket = parts
|
||||
.headers
|
||||
|
||||
if let Some(socket_path) = req
|
||||
.headers()
|
||||
.get("x-unix-socket")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string());
|
||||
|
||||
if let Some(socket_path) = unix_socket {
|
||||
.map(|v| v.to_string())
|
||||
{
|
||||
if !method_allowed {
|
||||
warn!(
|
||||
client = %client_addr,
|
||||
method = %parts.method,
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
method = %req.method(),
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"unix socket blocked by method policy"
|
||||
);
|
||||
return json_blocked("unix-socket", "method_not_allowed");
|
||||
return Ok(json_blocked("unix-socket", "method_not_allowed"));
|
||||
}
|
||||
|
||||
if !cfg!(target_os = "macos") {
|
||||
warn!(path = %socket_path, "unix socket proxy unsupported on this platform");
|
||||
return text_response(StatusCode::NOT_IMPLEMENTED, "unix sockets unsupported");
|
||||
return Ok(text_response(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"unix sockets unsupported",
|
||||
));
|
||||
}
|
||||
match state.is_unix_socket_allowed(&socket_path).await {
|
||||
|
||||
match app_state.is_unix_socket_allowed(&socket_path).await {
|
||||
Ok(true) => {
|
||||
info!(client = %client_addr, path = %socket_path, "unix socket allowed");
|
||||
match proxy_via_unix_socket(Request::from_parts(parts, body), &socket_path).await {
|
||||
Ok(resp) => return resp,
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
path = %socket_path,
|
||||
"unix socket allowed"
|
||||
);
|
||||
match proxy_via_unix_socket(ctx, req, &socket_path).await {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(err) => {
|
||||
warn!(error = %err, "unix socket proxy failed");
|
||||
return text_response(StatusCode::BAD_GATEWAY, "unix socket proxy failed");
|
||||
return Ok(text_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"unix socket proxy failed",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false) => {
|
||||
warn!(client = %client_addr, path = %socket_path, "unix socket blocked");
|
||||
return json_blocked("unix-socket", "not_allowed");
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
path = %socket_path,
|
||||
"unix socket blocked"
|
||||
);
|
||||
return Ok(json_blocked("unix-socket", "not_allowed"));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "unix socket check failed");
|
||||
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let host_header = parts
|
||||
.headers
|
||||
.get(HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string())
|
||||
.or_else(|| parts.uri.authority().map(|a| a.as_str().to_string()));
|
||||
|
||||
let authority = match host_header {
|
||||
Some(h) => h,
|
||||
None => return text_response(StatusCode::BAD_REQUEST, "missing host"),
|
||||
let authority = match ctx
|
||||
.get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
|
||||
.map(|ctx| ctx.authority.clone())
|
||||
{
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "missing host");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
|
||||
}
|
||||
};
|
||||
let authority = authority.trim().to_string();
|
||||
let host = normalize_host(&authority);
|
||||
if host.is_empty() {
|
||||
return text_response(StatusCode::BAD_REQUEST, "invalid host");
|
||||
}
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
|
||||
match state.host_blocked(&host).await {
|
||||
match app_state.host_blocked(&host).await {
|
||||
Ok((true, reason)) => {
|
||||
let _ = state
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
Some(client_addr.to_string()),
|
||||
Some(parts.method.to_string()),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(client = %client_addr, host = %host, reason = %reason, "request blocked");
|
||||
return json_blocked(&host, &reason);
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
reason = %reason,
|
||||
"request blocked"
|
||||
);
|
||||
return Ok(json_blocked(&host, &reason));
|
||||
}
|
||||
Ok((false, _)) => {}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
|
||||
if !method_allowed {
|
||||
let _ = state
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
Some(client_addr.to_string()),
|
||||
Some(parts.method.to_string()),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
Some(NetworkMode::Limited),
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client_addr,
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
method = %parts.method,
|
||||
method = %req.method(),
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"request blocked by method policy"
|
||||
);
|
||||
return json_blocked(&host, "method_not_allowed");
|
||||
return Ok(json_blocked(&host, "method_not_allowed"));
|
||||
}
|
||||
|
||||
info!(
|
||||
client = %client_addr,
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
method = %parts.method,
|
||||
method = %req.method(),
|
||||
"request allowed"
|
||||
);
|
||||
|
||||
let uri = match build_forward_uri(&authority, &parts.uri) {
|
||||
Ok(uri) => uri,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to build upstream uri");
|
||||
return text_response(StatusCode::BAD_REQUEST, "invalid uri");
|
||||
}
|
||||
};
|
||||
|
||||
let body_bytes = match to_bytes(body).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to read body");
|
||||
return text_response(StatusCode::BAD_GATEWAY, "failed to read body");
|
||||
}
|
||||
};
|
||||
|
||||
let mut builder = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(uri)
|
||||
.version(parts.version);
|
||||
let hop_headers = hop_by_hop_headers();
|
||||
for (name, value) in parts.headers.iter() {
|
||||
let name_str = name.as_str().to_ascii_lowercase();
|
||||
if hop_headers.contains(name_str.as_str())
|
||||
|| name == &HeaderName::from_static("x-unix-socket")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
|
||||
let forwarded_req = match builder.body(Body::from(body_bytes)) {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to build request");
|
||||
return text_response(StatusCode::BAD_GATEWAY, "invalid request");
|
||||
}
|
||||
};
|
||||
|
||||
match state.client.request(forwarded_req).await {
|
||||
Ok(resp) => filter_response(resp),
|
||||
let client = EasyHttpWebClient::default();
|
||||
match client.serve(ctx, req).await {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(err) => {
|
||||
warn!(error = %err, "upstream request failed");
|
||||
text_response(StatusCode::BAD_GATEWAY, "upstream failure")
|
||||
Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_forward_uri(authority: &str, uri: &Uri) -> Result<Uri> {
|
||||
let path = path_and_query(uri);
|
||||
let target = format!("http://{authority}{path}");
|
||||
Ok(target.parse()?)
|
||||
}
|
||||
|
||||
fn filter_response(resp: Response<Body>) -> Response<Body> {
|
||||
let mut builder = Response::builder().status(resp.status());
|
||||
let hop_headers = hop_by_hop_headers();
|
||||
for (name, value) in resp.headers().iter() {
|
||||
if hop_headers.contains(name.as_str().to_ascii_lowercase().as_str()) {
|
||||
continue;
|
||||
}
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder
|
||||
.body(resp.into_body())
|
||||
.unwrap_or_else(|_| Response::new(Body::from("proxy error")))
|
||||
}
|
||||
|
||||
fn path_and_query(uri: &Uri) -> String {
|
||||
uri.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("/")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn hop_by_hop_headers() -> HashSet<&'static str> {
|
||||
[
|
||||
"connection",
|
||||
"proxy-connection",
|
||||
"keep-alive",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"te",
|
||||
"trailer",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn split_authority(authority: &str) -> (String, u16) {
|
||||
if let Some(host) = authority.strip_prefix('[') {
|
||||
if let Some(end) = host.find(']') {
|
||||
let hostname = host[..end].to_string();
|
||||
let port = host[end + 1..]
|
||||
.strip_prefix(':')
|
||||
.and_then(|p| p.parse::<u16>().ok())
|
||||
.unwrap_or(443);
|
||||
return (hostname, port);
|
||||
}
|
||||
}
|
||||
let mut parts = authority.splitn(2, ':');
|
||||
let host = parts.next().unwrap_or("").to_string();
|
||||
let port = parts
|
||||
.next()
|
||||
.and_then(|p| p.parse::<u16>().ok())
|
||||
.unwrap_or(443);
|
||||
(host, port)
|
||||
}
|
||||
|
||||
async fn proxy_via_unix_socket(req: Request<Body>, socket_path: &str) -> Result<Response<Body>> {
|
||||
async fn proxy_via_unix_socket(
|
||||
ctx: ProxyContext,
|
||||
req: Request,
|
||||
socket_path: &str,
|
||||
) -> Result<Response> {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
use hyper::client::conn::Builder as ConnBuilder;
|
||||
use tokio::net::UnixStream;
|
||||
use rama::unix::client::UnixConnector;
|
||||
|
||||
let path = path_and_query(req.uri());
|
||||
let (parts, body) = req.into_parts();
|
||||
let body_bytes = to_bytes(body).await?;
|
||||
let mut builder = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(path)
|
||||
.version(parts.version);
|
||||
let hop_headers = hop_by_hop_headers();
|
||||
for (name, value) in parts.headers.iter() {
|
||||
let name_str = name.as_str().to_ascii_lowercase();
|
||||
if hop_headers.contains(name_str.as_str())
|
||||
|| name == &HeaderName::from_static("x-unix-socket")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
let req = builder.body(Body::from(body_bytes))?;
|
||||
let stream = UnixStream::connect(socket_path).await?;
|
||||
let (mut sender, conn) = ConnBuilder::new().handshake(stream).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = conn.await {
|
||||
warn!(error = %err, "unix socket connection error");
|
||||
}
|
||||
});
|
||||
Ok(sender.send_request(req).await?)
|
||||
let client = EasyHttpWebClient::builder()
|
||||
.with_custom_transport_connector(UnixConnector::fixed(socket_path))
|
||||
.without_tls_proxy_support()
|
||||
.without_proxy_support()
|
||||
.without_tls_support()
|
||||
.build();
|
||||
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let path = parts
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("/");
|
||||
parts.uri = path
|
||||
.parse()
|
||||
.with_context(|| format!("invalid unix socket request path: {path}"))?;
|
||||
parts.headers.remove("x-unix-socket");
|
||||
|
||||
let req = Request::from_parts(parts, body);
|
||||
Ok(client.serve(ctx, req).await?)
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = ctx;
|
||||
let _ = req;
|
||||
let _ = socket_path;
|
||||
Err(anyhow::anyhow!("unix sockets not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
fn client_addr(ctx: &ProxyContext) -> Option<String> {
|
||||
ctx.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string())
|
||||
}
|
||||
|
||||
fn json_blocked(host: &str, reason: &str) -> Response {
|
||||
let body = Body::from(json!({"status":"blocked","host":host,"reason":reason}).to_string());
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "application/json")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(body)
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
|
||||
fn blocked_text(reason: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "text/plain")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(Body::from(blocked_message(reason)))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
|
||||
fn text_response(status: StatusCode, body: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
|
||||
}
|
||||
|
||||
fn blocked_header_value(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
|
||||
"denied" => "blocked-by-denylist",
|
||||
"method_not_allowed" => "blocked-by-method-policy",
|
||||
"mitm_required" => "blocked-by-mitm-required",
|
||||
_ => "blocked-by-policy",
|
||||
}
|
||||
}
|
||||
|
||||
fn blocked_message(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" => "Codex blocked this request: domain not in allowlist.",
|
||||
"not_allowed_local" => "Codex blocked this request: local addresses not allowed.",
|
||||
"denied" => "Codex blocked this request: domain denied by policy.",
|
||||
"method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
|
||||
"mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
|
||||
_ => "Codex blocked this request by network policy.",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,21 +4,46 @@ mod imp {
|
||||
use crate::config::NetworkMode;
|
||||
use crate::policy::method_allowed;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::responses::text_response;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use hyper::Body;
|
||||
use hyper::Method;
|
||||
use hyper::Request;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use hyper::Uri;
|
||||
use hyper::Version;
|
||||
use hyper::body::HttpBody;
|
||||
use hyper::header::HOST;
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::service::service_fn;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Layer;
|
||||
use rama::Service;
|
||||
use rama::bytes::Bytes;
|
||||
use rama::error::BoxError;
|
||||
use rama::error::OpaqueError;
|
||||
use rama::futures::stream::Stream;
|
||||
use rama::http::Body;
|
||||
use rama::http::HeaderValue;
|
||||
use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::Uri;
|
||||
use rama::http::header::HOST;
|
||||
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama::http::layer::upgrade::Upgraded;
|
||||
use rama::http::server::HttpServer;
|
||||
use rama::net::proxy::ProxyTarget;
|
||||
use rama::net::stream::SocketInfo;
|
||||
use rama::service::service_fn;
|
||||
use rama::tls::rustls::dep::pemfile;
|
||||
use rama::tls::rustls::server::TlsAcceptorData;
|
||||
use rama::tls::rustls::server::TlsAcceptorDataBuilder;
|
||||
use rama::tls::rustls::server::TlsAcceptorLayer;
|
||||
use std::fs;
|
||||
use std::io::BufReader;
|
||||
use std::net::IpAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context as TaskContext;
|
||||
use std::task::Poll;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use rcgen::BasicConstraints;
|
||||
use rcgen::Certificate;
|
||||
use rcgen::CertificateParams;
|
||||
@@ -29,62 +54,11 @@ mod imp {
|
||||
use rcgen::KeyPair;
|
||||
use rcgen::KeyUsagePurpose;
|
||||
use rcgen::SanType;
|
||||
use rustls::Certificate as RustlsCertificate;
|
||||
use rustls::ClientConfig;
|
||||
use rustls::PrivateKey;
|
||||
use rustls::RootCertStore;
|
||||
use rustls::ServerConfig;
|
||||
use std::collections::HashSet;
|
||||
use std::convert::Infallible;
|
||||
use std::fs;
|
||||
use std::io::Cursor;
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum MitmProtocol {
|
||||
Http1,
|
||||
Http2,
|
||||
}
|
||||
|
||||
struct MitmTarget {
|
||||
host: String,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl MitmTarget {
|
||||
fn authority(&self) -> String {
|
||||
if self.port == 443 {
|
||||
self.host.clone()
|
||||
} else {
|
||||
format!("{}:{}", self.host, self.port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RequestLogContext {
|
||||
host: String,
|
||||
method: Method,
|
||||
path: String,
|
||||
}
|
||||
|
||||
struct ResponseLogContext {
|
||||
host: String,
|
||||
method: Method,
|
||||
path: String,
|
||||
status: StatusCode,
|
||||
}
|
||||
|
||||
pub struct MitmState {
|
||||
ca_key: KeyPair,
|
||||
ca_cert: Certificate,
|
||||
client_config: Arc<ClientConfig>,
|
||||
upstream: rama::service::BoxService<Arc<AppState>, Request, Response, OpaqueError>,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
@@ -98,30 +72,44 @@ mod imp {
|
||||
let ca_cert = ca_params
|
||||
.self_signed(&ca_key)
|
||||
.context("failed to reconstruct CA cert")?;
|
||||
let client_config = build_client_config()?;
|
||||
|
||||
let tls_config = rama::tls::rustls::client::TlsConnectorData::new_http_auto()
|
||||
.context("create upstream TLS config")?;
|
||||
let upstream = rama::http::client::EasyHttpWebClient::builder()
|
||||
.with_default_transport_connector()
|
||||
.without_tls_proxy_support()
|
||||
.without_proxy_support()
|
||||
.with_tls_support_using_rustls(Some(tls_config))
|
||||
.build()
|
||||
.boxed();
|
||||
|
||||
Ok(Self {
|
||||
ca_key,
|
||||
ca_cert,
|
||||
client_config,
|
||||
upstream,
|
||||
inspect: cfg.inspect,
|
||||
max_body_bytes: cfg.max_body_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn server_config_for_host(&self, host: &str) -> Result<Arc<ServerConfig>> {
|
||||
let (certs, key) = issue_host_certificate(host, &self.ca_cert, &self.ca_key)?;
|
||||
let mut config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.context("failed to build server TLS config")?;
|
||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
|
||||
let (cert_pem, key_pem) =
|
||||
issue_host_certificate_pem(host, &self.ca_cert, &self.ca_key)?;
|
||||
let cert_chain = pemfile::certs(&mut BufReader::new(cert_pem.as_bytes()))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.context("failed to parse host cert PEM")?;
|
||||
if cert_chain.is_empty() {
|
||||
return Err(anyhow!("no certificates found"));
|
||||
}
|
||||
|
||||
pub fn client_config(&self) -> Arc<ClientConfig> {
|
||||
Arc::clone(&self.client_config)
|
||||
let key_der = pemfile::private_key(&mut BufReader::new(key_pem.as_bytes()))
|
||||
.context("failed to parse host key PEM")?
|
||||
.context("no private key found")?;
|
||||
|
||||
Ok(TlsAcceptorDataBuilder::new(cert_chain, key_der)
|
||||
.context("failed to build rustls acceptor config")?
|
||||
.with_alpn_protocols_http_auto()
|
||||
.build())
|
||||
}
|
||||
|
||||
pub fn inspect_enabled(&self) -> bool {
|
||||
@@ -134,97 +122,87 @@ mod imp {
|
||||
}
|
||||
|
||||
pub async fn mitm_tunnel(
|
||||
stream: hyper::upgrade::Upgraded,
|
||||
mut ctx: RamaContext<Arc<AppState>>,
|
||||
upgraded: Upgraded,
|
||||
host: &str,
|
||||
port: u16,
|
||||
_port: u16,
|
||||
mode: NetworkMode,
|
||||
state: Arc<MitmState>,
|
||||
) -> Result<()> {
|
||||
let server_config = state.server_config_for_host(host)?;
|
||||
let acceptor = TlsAcceptor::from(server_config);
|
||||
let tls_stream = acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.context("client TLS handshake failed")?;
|
||||
let protocol = match tls_stream.get_ref().1.alpn_protocol() {
|
||||
Some(proto) if proto == b"h2" => MitmProtocol::Http2,
|
||||
_ => MitmProtocol::Http1,
|
||||
};
|
||||
info!(
|
||||
host = %host,
|
||||
port = port,
|
||||
protocol = ?protocol,
|
||||
mode = ?mode,
|
||||
inspect = state.inspect_enabled(),
|
||||
max_body_bytes = state.max_body_bytes(),
|
||||
"MITM TLS established"
|
||||
// Ensure the MITM state is available for the per-request handler.
|
||||
ctx.insert(state.clone());
|
||||
ctx.insert(mode);
|
||||
|
||||
let acceptor_data = state.tls_acceptor_data_for_host(host)?;
|
||||
let http_service = HttpServer::auto(ctx.executor().clone()).service(
|
||||
(
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
)
|
||||
.into_layer(service_fn(handle_mitm_request)),
|
||||
);
|
||||
|
||||
let target = Arc::new(MitmTarget {
|
||||
host: host.to_string(),
|
||||
port,
|
||||
});
|
||||
let service = {
|
||||
let state = state.clone();
|
||||
let target = target.clone();
|
||||
service_fn(move |req| handle_mitm_request(req, target.clone(), mode, state.clone()))
|
||||
};
|
||||
let https_service = TlsAcceptorLayer::new(acceptor_data)
|
||||
.with_store_client_hello(true)
|
||||
.into_layer(http_service);
|
||||
|
||||
let mut http = Http::new();
|
||||
match protocol {
|
||||
MitmProtocol::Http2 => {
|
||||
http.http2_only(true);
|
||||
}
|
||||
MitmProtocol::Http1 => {
|
||||
http.http1_only(true);
|
||||
}
|
||||
}
|
||||
http.serve_connection(tls_stream, service)
|
||||
https_service
|
||||
.serve(ctx, upgraded)
|
||||
.await
|
||||
.context("MITM HTTP handling failed")?;
|
||||
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_mitm_request(
|
||||
req: Request<Body>,
|
||||
target: Arc<MitmTarget>,
|
||||
mode: NetworkMode,
|
||||
state: Arc<MitmState>,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
let response = match forward_request(req, target.as_ref(), mode, state.as_ref()).await {
|
||||
ctx: RamaContext<Arc<AppState>>,
|
||||
req: Request,
|
||||
) -> Result<Response, std::convert::Infallible> {
|
||||
let response = match forward_request(ctx, req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
warn!(error = %err, host = %target.host, "MITM upstream request failed");
|
||||
warn!(error = %err, "MITM upstream request failed");
|
||||
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
|
||||
}
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn forward_request(
|
||||
req: Request<Body>,
|
||||
target: &MitmTarget,
|
||||
mode: NetworkMode,
|
||||
state: &MitmState,
|
||||
) -> Result<Response<Body>> {
|
||||
if req.method() == Method::CONNECT {
|
||||
async fn forward_request(ctx: RamaContext<Arc<AppState>>, req: Request) -> Result<Response> {
|
||||
let target = ctx
|
||||
.get::<ProxyTarget>()
|
||||
.context("missing proxy target")?
|
||||
.0
|
||||
.clone();
|
||||
|
||||
let target_host = normalize_host(&target.host().to_string());
|
||||
let target_port = target.port();
|
||||
let mode = ctx
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
let mitm = ctx
|
||||
.get::<Arc<MitmState>>()
|
||||
.cloned()
|
||||
.context("missing MITM state")?;
|
||||
|
||||
if req.method().as_str() == "CONNECT" {
|
||||
return Ok(text_response(
|
||||
StatusCode::METHOD_NOT_ALLOWED,
|
||||
"CONNECT not supported inside MITM",
|
||||
));
|
||||
}
|
||||
|
||||
let (parts, body) = req.into_parts();
|
||||
let request_version = parts.version;
|
||||
let method = parts.method.clone();
|
||||
let inspect = state.inspect_enabled();
|
||||
let max_body_bytes = state.max_body_bytes();
|
||||
let method = req.method().as_str().to_string();
|
||||
let path = path_and_query(req.uri());
|
||||
let client = ctx
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
|
||||
if let Some(request_host) = extract_request_host(&parts) {
|
||||
if let Some(request_host) = extract_request_host(&req) {
|
||||
let normalized = normalize_host(&request_host);
|
||||
if !normalized.is_empty() && normalized != target.host {
|
||||
if !normalized.is_empty() && normalized != target_host {
|
||||
warn!(
|
||||
target = %target.host,
|
||||
target = %target_host,
|
||||
request_host = %normalized,
|
||||
"MITM host mismatch"
|
||||
);
|
||||
@@ -232,175 +210,143 @@ mod imp {
|
||||
}
|
||||
}
|
||||
|
||||
let path = path_and_query(&parts.uri);
|
||||
let uri = build_origin_form_uri(&path)?;
|
||||
let authority = target.authority();
|
||||
|
||||
if !method_allowed(mode, &method) {
|
||||
if !method_allowed(mode, method.as_str()) {
|
||||
let _ = ctx
|
||||
.state()
|
||||
.record_blocked(BlockedRequest::new(
|
||||
target_host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
Some(method.clone()),
|
||||
Some(NetworkMode::Limited),
|
||||
"https".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
host = %authority,
|
||||
host = %target_host,
|
||||
method = %method,
|
||||
path = %path,
|
||||
mode = ?mode,
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"MITM blocked by method policy"
|
||||
);
|
||||
return Ok(text_response(StatusCode::FORBIDDEN, "method not allowed"));
|
||||
return Ok(blocked_text("method_not_allowed"));
|
||||
}
|
||||
|
||||
let mut builder = Request::builder()
|
||||
.method(method.clone())
|
||||
.uri(uri)
|
||||
.version(Version::HTTP_11);
|
||||
|
||||
let hop_headers = hop_by_hop_headers();
|
||||
for (name, value) in parts.headers.iter() {
|
||||
let name_str = name.as_str().to_ascii_lowercase();
|
||||
if hop_headers.contains(name_str.as_str()) || name == &HOST {
|
||||
continue;
|
||||
}
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder = builder.header(HOST, authority.as_str());
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let authority = authority_header_value(&target_host, target_port);
|
||||
parts.uri = build_https_uri(&authority, &path)?;
|
||||
parts
|
||||
.headers
|
||||
.insert(HOST, HeaderValue::from_str(&authority)?);
|
||||
|
||||
let inspect = mitm.inspect_enabled();
|
||||
let max_body_bytes = mitm.max_body_bytes();
|
||||
let body = if inspect {
|
||||
let (tx, out_body) = Body::channel();
|
||||
let ctx = RequestLogContext {
|
||||
host: authority.clone(),
|
||||
method: method.clone(),
|
||||
path: path.clone(),
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
stream_body(body, tx, max_body_bytes, ctx).await;
|
||||
});
|
||||
out_body
|
||||
inspect_body(
|
||||
body,
|
||||
max_body_bytes,
|
||||
RequestLogContext {
|
||||
host: authority.clone(),
|
||||
method: method.clone(),
|
||||
path: path.clone(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
body
|
||||
};
|
||||
|
||||
let upstream_req = builder
|
||||
.body(body)
|
||||
.context("failed to build upstream request")?;
|
||||
let upstream_resp = send_upstream_request(upstream_req, target, state).await?;
|
||||
|
||||
let upstream_req = Request::from_parts(parts, body);
|
||||
let upstream_resp = mitm.upstream.serve(ctx, upstream_req).await?;
|
||||
respond_with_inspection(
|
||||
upstream_resp,
|
||||
request_version,
|
||||
inspect,
|
||||
max_body_bytes,
|
||||
&method,
|
||||
&path,
|
||||
&authority,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_upstream_request(
|
||||
req: Request<Body>,
|
||||
target: &MitmTarget,
|
||||
state: &MitmState,
|
||||
) -> Result<Response<Body>> {
|
||||
let upstream = TcpStream::connect((target.host.as_str(), target.port))
|
||||
.await
|
||||
.context("failed to connect to upstream")?;
|
||||
let server_name = match target.host.parse::<IpAddr>() {
|
||||
Ok(ip) => rustls::ServerName::IpAddress(ip),
|
||||
Err(_) => rustls::ServerName::try_from(target.host.as_str())
|
||||
.map_err(|_| anyhow!("invalid server name"))?,
|
||||
};
|
||||
let connector = TlsConnector::from(state.client_config());
|
||||
let tls_stream = connector
|
||||
.connect(server_name, upstream)
|
||||
.await
|
||||
.context("upstream TLS handshake failed")?;
|
||||
let (mut sender, conn) = hyper::client::conn::Builder::new()
|
||||
.handshake(tls_stream)
|
||||
.await
|
||||
.context("upstream HTTP handshake failed")?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = conn.await {
|
||||
warn!(error = %err, "MITM upstream connection error");
|
||||
}
|
||||
});
|
||||
let resp = sender
|
||||
.send_request(req)
|
||||
.await
|
||||
.context("upstream request failed")?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn respond_with_inspection(
|
||||
resp: Response<Body>,
|
||||
request_version: Version,
|
||||
fn respond_with_inspection(
|
||||
resp: Response,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
method: &Method,
|
||||
method: &str,
|
||||
path: &str,
|
||||
authority: &str,
|
||||
) -> Result<Response<Body>> {
|
||||
let (parts, body) = resp.into_parts();
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(parts.status)
|
||||
.version(request_version);
|
||||
let hop_headers = hop_by_hop_headers();
|
||||
for (name, value) in parts.headers.iter() {
|
||||
if hop_headers.contains(name.as_str().to_ascii_lowercase().as_str()) {
|
||||
continue;
|
||||
}
|
||||
builder = builder.header(name, value);
|
||||
) -> Result<Response> {
|
||||
if !inspect {
|
||||
return Ok(resp);
|
||||
}
|
||||
let body = if inspect {
|
||||
let (tx, out_body) = Body::channel();
|
||||
let ctx = ResponseLogContext {
|
||||
|
||||
let (parts, body) = resp.into_parts();
|
||||
let body = inspect_body(
|
||||
body,
|
||||
max_body_bytes,
|
||||
ResponseLogContext {
|
||||
host: authority.to_string(),
|
||||
method: method.clone(),
|
||||
method: method.to_string(),
|
||||
path: path.to_string(),
|
||||
status: parts.status,
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
stream_body(body, tx, max_body_bytes, ctx).await;
|
||||
});
|
||||
out_body
|
||||
} else {
|
||||
body
|
||||
};
|
||||
Ok(builder
|
||||
.body(body)
|
||||
.unwrap_or_else(|_| Response::new(Body::from("proxy error"))))
|
||||
},
|
||||
);
|
||||
Ok(Response::from_parts(parts, body))
|
||||
}
|
||||
|
||||
async fn stream_body<T>(
|
||||
mut body: Body,
|
||||
mut tx: hyper::body::Sender,
|
||||
fn inspect_body<T: BodyLoggable + Send + 'static>(
|
||||
body: Body,
|
||||
max_body_bytes: usize,
|
||||
ctx: T,
|
||||
) where
|
||||
T: BodyLoggable,
|
||||
{
|
||||
let mut len: usize = 0;
|
||||
let mut truncated = false;
|
||||
while let Some(chunk) = body.data().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
len = len.saturating_add(bytes.len());
|
||||
if len > max_body_bytes {
|
||||
truncated = true;
|
||||
}
|
||||
if tx.send_data(bytes).await.is_err() {
|
||||
break;
|
||||
}
|
||||
) -> Body {
|
||||
Body::from_stream(InspectStream {
|
||||
inner: Box::pin(body.into_data_stream()),
|
||||
ctx: Some(Box::new(ctx)),
|
||||
len: 0,
|
||||
max_body_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
struct InspectStream<T> {
|
||||
inner: Pin<Box<rama::http::BodyDataStream>>,
|
||||
ctx: Option<Box<T>>,
|
||||
len: usize,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
|
||||
impl<T: BodyLoggable> Stream for InspectStream<T> {
|
||||
type Item = Result<Bytes, BoxError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
match this.inner.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(bytes))) => {
|
||||
this.len = this.len.saturating_add(bytes.len());
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "MITM body stream error");
|
||||
break;
|
||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(None) => {
|
||||
if let Some(ctx) = this.ctx.take() {
|
||||
ctx.log(this.len, this.len > this.max_body_bytes);
|
||||
}
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
if let Ok(Some(trailers)) = body.trailers().await {
|
||||
let _ = tx.send_trailers(trailers).await;
|
||||
}
|
||||
ctx.log(len, truncated);
|
||||
}
|
||||
|
||||
struct RequestLogContext {
|
||||
host: String,
|
||||
method: String,
|
||||
path: String,
|
||||
}
|
||||
|
||||
struct ResponseLogContext {
|
||||
host: String,
|
||||
method: String,
|
||||
path: String,
|
||||
status: StatusCode,
|
||||
}
|
||||
|
||||
trait BodyLoggable {
|
||||
@@ -434,13 +380,32 @@ mod imp {
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_request_host(parts: &hyper::http::request::Parts) -> Option<String> {
|
||||
parts
|
||||
.headers
|
||||
fn extract_request_host(req: &Request) -> Option<String> {
|
||||
req.headers()
|
||||
.get(HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string())
|
||||
.or_else(|| parts.uri.authority().map(|a| a.as_str().to_string()))
|
||||
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
|
||||
}
|
||||
|
||||
fn authority_header_value(host: &str, port: u16) -> String {
|
||||
// Host header / URI authority formatting.
|
||||
if host.contains(':') {
|
||||
if port == 443 {
|
||||
format!("[{host}]")
|
||||
} else {
|
||||
format!("[{host}]:{port}")
|
||||
}
|
||||
} else if port == 443 {
|
||||
host.to_string()
|
||||
} else {
|
||||
format!("{host}:{port}")
|
||||
}
|
||||
}
|
||||
|
||||
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
|
||||
let target = format!("https://{authority}{path}");
|
||||
Ok(target.parse()?)
|
||||
}
|
||||
|
||||
fn path_and_query(uri: &Uri) -> String {
|
||||
@@ -450,51 +415,11 @@ mod imp {
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn build_origin_form_uri(path: &str) -> Result<Uri> {
|
||||
path.parse().context("invalid request path")
|
||||
}
|
||||
|
||||
fn hop_by_hop_headers() -> HashSet<&'static str> {
|
||||
[
|
||||
"connection",
|
||||
"proxy-connection",
|
||||
"keep-alive",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"te",
|
||||
"trailer",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build_client_config() -> Result<Arc<ClientConfig>> {
|
||||
let mut roots = RootCertStore::empty();
|
||||
let certs = rustls_native_certs::load_native_certs()
|
||||
.map_err(|err| anyhow!("failed to load native certs: {err}"))?;
|
||||
for cert in certs {
|
||||
if roots.add(&RustlsCertificate(cert.0)).is_err() {
|
||||
warn!("skipping invalid root cert");
|
||||
}
|
||||
}
|
||||
if roots.is_empty() {
|
||||
return Err(anyhow!("no root certificates available"));
|
||||
}
|
||||
let mut config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
|
||||
fn issue_host_certificate(
|
||||
fn issue_host_certificate_pem(
|
||||
host: &str,
|
||||
ca_cert: &Certificate,
|
||||
ca_key: &KeyPair,
|
||||
) -> Result<(Vec<RustlsCertificate>, PrivateKey)> {
|
||||
) -> Result<(String, String)> {
|
||||
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
let mut params = CertificateParams::new(Vec::new())
|
||||
.map_err(|err| anyhow!("failed to create cert params: {err}"))?;
|
||||
@@ -504,6 +429,7 @@ mod imp {
|
||||
CertificateParams::new(vec![host.to_string()])
|
||||
.map_err(|err| anyhow!("failed to create cert params: {err}"))?
|
||||
};
|
||||
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
@@ -516,16 +442,13 @@ mod imp {
|
||||
.signed_by(&key_pair, ca_cert, ca_key)
|
||||
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
|
||||
|
||||
let cert_pem = cert.pem();
|
||||
let key_pem = key_pair.serialize_pem();
|
||||
let certs = certs_from_pem(&cert_pem)?;
|
||||
let key = private_key_from_pem(&key_pem)?;
|
||||
Ok((certs, key))
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
}
|
||||
|
||||
fn load_or_create_ca(cfg: &MitmConfig) -> Result<(String, String)> {
|
||||
let cert_path = &cfg.ca_cert_path;
|
||||
let key_path = &cfg.ca_key_path;
|
||||
|
||||
if cert_path.exists() || key_path.exists() {
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
return Err(anyhow!("both ca_cert_path and ca_key_path must exist"));
|
||||
@@ -574,63 +497,75 @@ mod imp {
|
||||
let cert = params
|
||||
.self_signed(&key_pair)
|
||||
.map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
|
||||
let cert_pem = cert.pem();
|
||||
let key_pem = key_pair.serialize_pem();
|
||||
Ok((cert_pem, key_pem))
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
}
|
||||
|
||||
fn certs_from_pem(pem: &str) -> Result<Vec<RustlsCertificate>> {
|
||||
let mut reader = Cursor::new(pem);
|
||||
let certs = rustls_pemfile::certs(&mut reader).context("failed to parse cert PEM")?;
|
||||
if certs.is_empty() {
|
||||
return Err(anyhow!("no certificates found"));
|
||||
}
|
||||
Ok(certs.into_iter().map(RustlsCertificate).collect())
|
||||
}
|
||||
|
||||
fn private_key_from_pem(pem: &str) -> Result<PrivateKey> {
|
||||
let mut reader = Cursor::new(pem);
|
||||
let mut keys =
|
||||
rustls_pemfile::pkcs8_private_keys(&mut reader).context("failed to parse pkcs8 key")?;
|
||||
if let Some(key) = keys.pop() {
|
||||
return Ok(PrivateKey(key));
|
||||
}
|
||||
let mut reader = Cursor::new(pem);
|
||||
let mut keys =
|
||||
rustls_pemfile::rsa_private_keys(&mut reader).context("failed to parse rsa key")?;
|
||||
if let Some(key) = keys.pop() {
|
||||
return Ok(PrivateKey(key));
|
||||
}
|
||||
Err(anyhow!("no private key found"))
|
||||
}
|
||||
|
||||
fn write_private_file(path: &Path, contents: &[u8], mode: u32) -> Result<()> {
|
||||
fn write_private_file(path: &std::path::Path, contents: &[u8], mode: u32) -> Result<()> {
|
||||
fs::write(path, contents).with_context(|| format!("failed to write {}", path.display()))?;
|
||||
set_permissions(path, mode)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn set_permissions(path: &Path, mode: u32) -> Result<()> {
|
||||
fn set_permissions(path: &std::path::Path, mode: u32) -> Result<()> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
fs::set_permissions(path, fs::Permissions::from_mode(mode))
|
||||
.with_context(|| format!("failed to set permissions on {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn set_permissions(_path: &Path, _mode: u32) -> Result<()> {
|
||||
fn set_permissions(_path: &std::path::Path, _mode: u32) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn blocked_text(reason: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "text/plain")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(Body::from(blocked_message(reason)))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
|
||||
fn text_response(status: StatusCode, body: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
|
||||
}
|
||||
|
||||
fn blocked_header_value(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
|
||||
"denied" => "blocked-by-denylist",
|
||||
"method_not_allowed" => "blocked-by-method-policy",
|
||||
"mitm_required" => "blocked-by-mitm-required",
|
||||
_ => "blocked-by-policy",
|
||||
}
|
||||
}
|
||||
|
||||
fn blocked_message(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"method_not_allowed" => {
|
||||
"Codex blocked this request: method not allowed in limited mode."
|
||||
}
|
||||
_ => "Codex blocked this request by network policy.",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "mitm"))]
|
||||
mod imp {
|
||||
use crate::config::MitmConfig;
|
||||
use crate::config::NetworkMode;
|
||||
use crate::state::AppState;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::http::layer::upgrade::Upgraded;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -652,7 +587,8 @@ mod imp {
|
||||
}
|
||||
|
||||
pub async fn mitm_tunnel(
|
||||
_stream: Upgraded,
|
||||
_ctx: RamaContext<Arc<AppState>>,
|
||||
_upgraded: Upgraded,
|
||||
_host: &str,
|
||||
_port: u16,
|
||||
_mode: NetworkMode,
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use crate::config::NetworkMode;
|
||||
use hyper::Method;
|
||||
use std::net::IpAddr;
|
||||
|
||||
pub fn method_allowed(mode: NetworkMode, method: &Method) -> bool {
|
||||
pub fn method_allowed(mode: NetworkMode, method: &str) -> bool {
|
||||
match mode {
|
||||
NetworkMode::Full => true,
|
||||
NetworkMode::Limited => matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS),
|
||||
NetworkMode::Limited => matches!(method, "GET" | "HEAD" | "OPTIONS"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,26 +2,6 @@ use hyper::Body;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
|
||||
pub fn json_blocked(host: &str, reason: &str) -> Response<Body> {
|
||||
let body = Body::from(json!({"status":"blocked","host":host,"reason":reason}).to_string());
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "application/json")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(body)
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
|
||||
pub fn blocked_text(reason: &str) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "text/plain")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(Body::from(blocked_message(reason).to_string()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
|
||||
pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
|
||||
Response::builder()
|
||||
@@ -42,24 +22,3 @@ pub fn json_response<T: Serialize>(value: &T) -> Response<Body> {
|
||||
.body(Body::from(body))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("{}")))
|
||||
}
|
||||
|
||||
fn blocked_header_value(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
|
||||
"denied" => "blocked-by-denylist",
|
||||
"method_not_allowed" => "blocked-by-method-policy",
|
||||
"mitm_required" => "blocked-by-mitm-required",
|
||||
_ => "blocked-by-policy",
|
||||
}
|
||||
}
|
||||
|
||||
fn blocked_message(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" => "Codex blocked this request: domain not in allowlist.",
|
||||
"not_allowed_local" => "Codex blocked this request: local addresses not allowed.",
|
||||
"denied" => "Codex blocked this request: domain denied by policy.",
|
||||
"method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
|
||||
"mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
|
||||
_ => "Codex blocked this request by network policy.",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,20 +4,30 @@ use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Service;
|
||||
use rama::net::stream::SocketInfo;
|
||||
use rama::proxy::socks5::Socks5Acceptor;
|
||||
use rama::proxy::socks5::server::DefaultConnector;
|
||||
use rama::service::service_fn;
|
||||
use rama::tcp::client::Request as TcpRequest;
|
||||
use rama::tcp::client::service::TcpConnector;
|
||||
use rama::tcp::server::TcpListener;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::copy_bidirectional;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listener = TcpListener::build_with_state(state.clone())
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow!("bind SOCKS5 proxy: {err}"))?;
|
||||
|
||||
info!(addr = %addr, "SOCKS5 proxy listening");
|
||||
|
||||
match state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
info!(
|
||||
@@ -30,163 +40,93 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
warn!(error = %err, "failed to read network mode");
|
||||
}
|
||||
}
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_socks5_client(stream, peer_addr, state).await {
|
||||
warn!(error = %err, "SOCKS5 session ended with error");
|
||||
|
||||
let tcp_connector = TcpConnector::default();
|
||||
let policy_tcp_connector =
|
||||
service_fn(move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
|
||||
let tcp_connector = tcp_connector.clone();
|
||||
async move {
|
||||
let app_state = ctx.state().clone();
|
||||
let authority = req.authority().clone();
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
let port = authority.port();
|
||||
let client = ctx
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
|
||||
match app_state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
Some(NetworkMode::Limited),
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"SOCKS blocked by method policy"
|
||||
);
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate method policy");
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
match app_state.host_blocked(&host).await {
|
||||
Ok((true, reason)) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
reason = %reason,
|
||||
"SOCKS blocked"
|
||||
);
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
port = port,
|
||||
"SOCKS allowed"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
tcp_connector.serve(ctx, req).await
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_socks5_client(
|
||||
mut stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<()> {
|
||||
let mut header = [0u8; 2];
|
||||
stream.read_exact(&mut header).await?;
|
||||
if header[0] != 0x05 {
|
||||
return Err(anyhow!("invalid SOCKS version"));
|
||||
}
|
||||
let nmethods = header[1] as usize;
|
||||
let mut methods = vec![0u8; nmethods];
|
||||
stream.read_exact(&mut methods).await?;
|
||||
stream.write_all(&[0x05, 0x00]).await?;
|
||||
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
|
||||
let socks_acceptor = Socks5Acceptor::new().with_connector(socks_connector);
|
||||
|
||||
let mut req_header = [0u8; 4];
|
||||
stream.read_exact(&mut req_header).await?;
|
||||
if req_header[0] != 0x05 {
|
||||
return Err(anyhow!("invalid SOCKS request version"));
|
||||
}
|
||||
let cmd = req_header[1];
|
||||
if cmd != 0x01 {
|
||||
stream
|
||||
.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Err(anyhow!("unsupported SOCKS command"));
|
||||
}
|
||||
let atyp = req_header[3];
|
||||
let host = match atyp {
|
||||
0x01 => {
|
||||
let mut addr = [0u8; 4];
|
||||
stream.read_exact(&mut addr).await?;
|
||||
format!("{}.{}.{}.{}", addr[0], addr[1], addr[2], addr[3])
|
||||
}
|
||||
0x03 => {
|
||||
let mut len_buf = [0u8; 1];
|
||||
stream.read_exact(&mut len_buf).await?;
|
||||
let len = len_buf[0] as usize;
|
||||
let mut domain = vec![0u8; len];
|
||||
stream.read_exact(&mut domain).await?;
|
||||
String::from_utf8_lossy(&domain).to_string()
|
||||
}
|
||||
0x04 => {
|
||||
stream
|
||||
.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Err(anyhow!("ipv6 not supported"));
|
||||
}
|
||||
_ => {
|
||||
stream
|
||||
.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Err(anyhow!("unknown address type"));
|
||||
}
|
||||
};
|
||||
|
||||
let mut port_buf = [0u8; 2];
|
||||
stream.read_exact(&mut port_buf).await?;
|
||||
let port = u16::from_be_bytes(port_buf);
|
||||
let normalized_host = normalize_host(&host);
|
||||
|
||||
match state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
let _ = state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
normalized_host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
Some(peer_addr.to_string()),
|
||||
None,
|
||||
Some(NetworkMode::Limited),
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %peer_addr,
|
||||
host = %normalized_host,
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"SOCKS blocked by method policy"
|
||||
);
|
||||
stream
|
||||
.write_all(&[0x05, 0x02, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate method policy");
|
||||
stream
|
||||
.write_all(&[0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
match state.host_blocked(&normalized_host).await {
|
||||
Ok((true, reason)) => {
|
||||
let _ = state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
normalized_host.clone(),
|
||||
reason.clone(),
|
||||
Some(peer_addr.to_string()),
|
||||
None,
|
||||
None,
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(client = %peer_addr, host = %normalized_host, reason = %reason, "SOCKS blocked");
|
||||
stream
|
||||
.write_all(&[0x05, 0x02, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
info!(
|
||||
client = %peer_addr,
|
||||
host = %normalized_host,
|
||||
port = port,
|
||||
"SOCKS allowed"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
stream
|
||||
.write_all(&[0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let target = format!("{host}:{port}");
|
||||
let mut upstream = match TcpStream::connect(&target).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "SOCKS connect failed");
|
||||
stream
|
||||
.write_all(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
stream
|
||||
.write_all(&[0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
.await?;
|
||||
|
||||
let _ = copy_bidirectional(&mut stream, &mut upstream).await;
|
||||
listener.serve(socks_acceptor).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -10,9 +10,6 @@ use anyhow::anyhow;
|
||||
use globset::GlobBuilder;
|
||||
use globset::GlobSet;
|
||||
use globset::GlobSetBuilder;
|
||||
use hyper::Client;
|
||||
use hyper::Method;
|
||||
use hyper::client::HttpConnector;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
@@ -72,16 +69,13 @@ struct ConfigState {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub(crate) client: Client<HttpConnector>,
|
||||
state: Arc<RwLock<ConfigState>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new(cfg_path: PathBuf) -> Result<Self> {
|
||||
let cfg_state = build_config_state(cfg_path)?;
|
||||
let client = Client::new();
|
||||
Ok(Self {
|
||||
client,
|
||||
state: Arc::new(RwLock::new(cfg_state)),
|
||||
})
|
||||
}
|
||||
@@ -171,7 +165,7 @@ impl AppState {
|
||||
.any(|p| p == path))
|
||||
}
|
||||
|
||||
pub async fn method_allowed(&self, method: &Method) -> Result<bool> {
|
||||
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(method_allowed(guard.cfg.network_proxy.mode, method))
|
||||
|
||||
Reference in New Issue
Block a user