address feedback

This commit is contained in:
viyatb-oai
2025-12-23 18:03:55 -08:00
parent 127b89b4ed
commit 9d473922e3
12 changed files with 628 additions and 708 deletions

View File

@@ -3,97 +3,119 @@ use crate::responses::json_response;
use crate::responses::text_response;
use crate::state::AppState;
use anyhow::Result;
use hyper::Body;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::Server;
use hyper::StatusCode;
use hyper::body::to_bytes;
use hyper::service::make_service_fn;
use hyper::service::service_fn;
use rama::Context as RamaContext;
use rama::http::Body;
use rama::http::Request;
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::server::HttpServer;
use rama::service::service_fn;
use rama::tcp::server::TcpListener;
use serde::Deserialize;
use serde_json::json;
use serde::Serialize;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
type ContextState = Arc<AppState>;
type AdminContext = RamaContext<ContextState>;
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let make_svc = make_service_fn(move |_conn: &hyper::server::conn::AddrStream| {
let state = state.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
handle_admin_request(req, state.clone())
}))
}
});
let server = Server::bind(&addr).serve(make_svc);
info!(addr = %addr, "admin API listening");
server.await?;
let listener = TcpListener::build_with_state(state)
.bind(addr)
.await
.map_err(|err| anyhow::anyhow!("bind admin API: {err}"))?;
let server =
HttpServer::auto(rama::rt::Executor::new()).service(service_fn(handle_admin_request));
info!("admin API listening on {addr}");
listener.serve(server).await;
Ok(())
}
async fn handle_admin_request(
req: Request<Body>,
state: Arc<AppState>,
) -> Result<Response<Body>, Infallible> {
async fn handle_admin_request(ctx: AdminContext, req: Request) -> Result<Response, Infallible> {
const MODE_BODY_LIMIT: usize = 8 * 1024;
let state = ctx.state().clone();
let method = req.method().clone();
let path = req.uri().path().to_string();
let response = match (method, path.as_str()) {
(Method::GET, "/health") => Response::new(Body::from("ok")),
(Method::GET, "/config") => match state.current_cfg().await {
let response = match (method.as_str(), path.as_str()) {
("GET", "/health") => Response::new(Body::from("ok")),
("GET", "/config") => match state.current_cfg().await {
Ok(cfg) => json_response(&cfg),
Err(err) => {
error!(error = %err, "failed to load config");
error!("failed to load config: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
(Method::GET, "/patterns") => match state.current_patterns().await {
Ok((allow, deny)) => json_response(&json!({"allowed": allow, "denied": deny})),
("GET", "/patterns") => match state.current_patterns().await {
Ok((allow, deny)) => json_response(&PatternsResponse {
allowed: allow,
denied: deny,
}),
Err(err) => {
error!(error = %err, "failed to load patterns");
error!("failed to load patterns: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
(Method::GET, "/blocked") => match state.drain_blocked().await {
Ok(blocked) => json_response(&json!({ "blocked": blocked })),
("GET", "/blocked") => match state.drain_blocked().await {
Ok(blocked) => json_response(&BlockedResponse { blocked }),
Err(err) => {
error!(error = %err, "failed to read blocked queue");
error!("failed to read blocked queue: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
(Method::POST, "/mode") => {
let body = match to_bytes(req.into_body()).await {
Ok(bytes) => bytes,
Err(err) => {
error!(error = %err, "failed to read mode body");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
("POST", "/mode") => {
let mut body = req.into_body();
let mut buf: Vec<u8> = Vec::new();
loop {
let chunk = match body.chunk().await {
Ok(chunk) => chunk,
Err(err) => {
error!("failed to read mode body: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
}
};
let Some(chunk) = chunk else {
break;
};
if buf.len().saturating_add(chunk.len()) > MODE_BODY_LIMIT {
return Ok(text_response(
StatusCode::PAYLOAD_TOO_LARGE,
"body too large",
));
}
};
if body.is_empty() {
buf.extend_from_slice(&chunk);
}
if buf.is_empty() {
return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
}
let update: ModeUpdate = match serde_json::from_slice(&body) {
let update: ModeUpdate = match serde_json::from_slice(&buf) {
Ok(update) => update,
Err(err) => {
error!(error = %err, "failed to parse mode update");
error!("failed to parse mode update: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
}
};
match state.set_network_mode(update.mode).await {
Ok(()) => json_response(&json!({"status": "ok", "mode": update.mode})),
Ok(()) => json_response(&ModeUpdateResponse {
status: "ok",
mode: update.mode,
}),
Err(err) => {
error!(error = %err, "mode update failed");
error!("mode update failed: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
}
}
}
(Method::POST, "/reload") => match state.force_reload().await {
Ok(()) => json_response(&json!({"status": "reloaded"})),
("POST", "/reload") => match state.force_reload().await {
Ok(()) => json_response(&ReloadResponse { status: "reloaded" }),
Err(err) => {
error!(error = %err, "reload failed");
error!("reload failed: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
}
},
@@ -106,3 +128,25 @@ async fn handle_admin_request(
struct ModeUpdate {
mode: NetworkMode,
}
#[derive(Debug, Serialize)]
struct PatternsResponse {
allowed: Vec<String>,
denied: Vec<String>,
}
#[derive(Debug, Serialize)]
struct BlockedResponse<T> {
blocked: T,
}
#[derive(Debug, Serialize)]
struct ModeUpdateResponse {
status: &'static str,
mode: NetworkMode,
}
#[derive(Debug, Serialize)]
struct ReloadResponse {
status: &'static str,
}

View File

@@ -1,26 +1,15 @@
use anyhow::Context;
use anyhow::Result;
use codex_core::config::default_config_path;
use serde::Deserialize;
use serde::Serialize;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::path::PathBuf;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
#[serde(default)]
pub network_proxy: NetworkProxyConfig,
}
impl Default for Config {
fn default() -> Self {
Self {
network_proxy: NetworkProxyConfig::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkProxyConfig {
#[serde(default)]
@@ -50,42 +39,26 @@ impl Default for NetworkProxyConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NetworkPolicy {
#[serde(default, rename = "allowed_domains", alias = "allowedDomains")]
#[serde(default)]
pub allowed_domains: Vec<String>,
#[serde(default, rename = "denied_domains", alias = "deniedDomains")]
#[serde(default)]
pub denied_domains: Vec<String>,
#[serde(default, rename = "allow_unix_sockets", alias = "allowUnixSockets")]
#[serde(default)]
pub allow_unix_sockets: Vec<String>,
#[serde(default, rename = "allow_local_binding", alias = "allowLocalBinding")]
#[serde(default)]
pub allow_local_binding: bool,
}
impl Default for NetworkPolicy {
fn default() -> Self {
Self {
allowed_domains: Vec::new(),
denied_domains: Vec::new(),
allow_unix_sockets: Vec::new(),
allow_local_binding: false,
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum NetworkMode {
Limited,
#[default]
Full,
}
impl Default for NetworkMode {
fn default() -> Self {
NetworkMode::Full
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitmConfig {
#[serde(default)]
@@ -138,10 +111,6 @@ pub struct RuntimeConfig {
pub admin_addr: SocketAddr,
}
pub fn default_codex_config_path() -> Result<PathBuf> {
default_config_path().context("failed to resolve Codex config path")
}
pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
let http_addr = resolve_addr(&cfg.network_proxy.proxy_url, 3128);
let admin_addr = resolve_addr(&cfg.network_proxy.admin_url, 8080);
@@ -155,22 +124,30 @@ pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
}
fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
let (host, port) = parse_host_port(url, default_port);
let host = if host.eq_ignore_ascii_case("localhost") {
let addr_parts = parse_host_port(url, default_port);
let host = if addr_parts.host.eq_ignore_ascii_case("localhost") {
"127.0.0.1"
} else {
host
addr_parts.host
};
match host.parse::<IpAddr>() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => SocketAddr::from(([127, 0, 0, 1], port)),
Ok(ip) => SocketAddr::new(ip, addr_parts.port),
Err(_) => SocketAddr::from(([127, 0, 0, 1], addr_parts.port)),
}
}
fn parse_host_port(url: &str, default_port: u16) -> (&str, u16) {
struct SocketAddressParts<'a> {
host: &'a str,
port: u16,
}
fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
let trimmed = url.trim();
if trimmed.is_empty() {
return ("127.0.0.1", default_port);
return SocketAddressParts {
host: "127.0.0.1",
port: default_port,
};
}
let without_scheme = trimmed
.split_once("://")
@@ -182,22 +159,25 @@ fn parse_host_port(url: &str, default_port: u16) -> (&str, u16) {
.map(|(_, rest)| rest)
.unwrap_or(host_port);
if host_port.starts_with('[') {
if let Some(end) = host_port.find(']') {
let host = &host_port[1..end];
let port = host_port[end + 1..]
.strip_prefix(':')
.and_then(|port| port.parse::<u16>().ok())
.unwrap_or(default_port);
return (host, port);
}
if host_port.starts_with('[')
&& let Some(end) = host_port.find(']')
{
let host = &host_port[1..end];
let port = host_port[end + 1..]
.strip_prefix(':')
.and_then(|port| port.parse::<u16>().ok())
.unwrap_or(default_port);
return SocketAddressParts { host, port };
}
if let Some((host, port)) = host_port.rsplit_once(':') {
if let Ok(port) = port.parse::<u16>() {
return (host, port);
}
if let Some((host, port)) = host_port.rsplit_once(':')
&& let Ok(port) = port.parse::<u16>()
{
return SocketAddressParts { host, port };
}
(host_port, default_port)
SocketAddressParts {
host: host_port,
port: default_port,
}
}

View File

@@ -5,7 +5,6 @@ use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use rama::Context as RamaContext;
use rama::Layer;
use rama::Service;
@@ -14,6 +13,7 @@ use rama::http::Request;
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::client::EasyHttpWebClient;
use rama::http::dep::http::uri::PathAndQuery;
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
use rama::http::layer::upgrade::UpgradeLayer;
@@ -41,7 +41,8 @@ pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()
let listener = TcpListener::build_with_state(state)
.bind(addr)
.await
.map_err(|err| anyhow!("bind HTTP proxy: {err}"))?;
.map_err(|err| anyhow::anyhow!(err))
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
let http_service = HttpServer::auto(rama::rt::Executor::new()).service(
(
@@ -56,7 +57,7 @@ pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()
.into_layer(service_fn(http_plain_proxy)),
);
info!(addr = %addr, "HTTP proxy listening");
info!("HTTP proxy listening on {addr}");
listener.serve(http_service).await;
Ok(())
@@ -72,7 +73,7 @@ async fn http_connect_accept(
{
Ok(authority) => authority,
Err(err) => {
warn!(error = %err, "CONNECT missing authority");
warn!("CONNECT missing authority: {err}");
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
}
};
@@ -97,23 +98,16 @@ async fn http_connect_accept(
"http-connect".to_string(),
))
.await;
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
reason = %reason,
"CONNECT blocked"
);
let client = client.as_deref().unwrap_or_default();
warn!("CONNECT blocked (client={client}, host={host}, reason={reason})");
return Err(blocked_text(&reason));
}
Ok((false, _)) => {
info!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
"CONNECT allowed"
);
let client = client.as_deref().unwrap_or_default();
info!("CONNECT allowed (client={client}, host={host})");
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
error!("failed to evaluate host for CONNECT {host}: {err}");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
@@ -121,7 +115,7 @@ async fn http_connect_accept(
let mode = match app_state.network_mode().await {
Ok(mode) => mode,
Err(err) => {
error!(error = %err, "failed to read network mode");
error!("failed to read network mode: {err}");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
@@ -129,7 +123,7 @@ async fn http_connect_accept(
let mitm_state = match app_state.mitm_state().await {
Ok(state) => state,
Err(err) => {
error!(error = %err, "failed to load MITM state");
error!("failed to load MITM state: {err}");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
@@ -145,12 +139,9 @@ async fn http_connect_accept(
"http-connect".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"CONNECT blocked; MITM required for read-only HTTPS in limited mode"
"CONNECT blocked; MITM required for read-only HTTPS in limited mode (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Err(blocked_text("mitm_required"));
}
@@ -186,7 +177,8 @@ async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(),
let host = normalize_host(&authority.host().to_string());
if let Some(mitm_state) = ctx.get::<Arc<mitm::MitmState>>().cloned() {
info!(host = %host, port = authority.port(), mode = ?mode, "CONNECT MITM enabled");
let port = authority.port();
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
if let Err(err) = mitm::mitm_tunnel(
ctx,
upgraded,
@@ -197,14 +189,14 @@ async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(),
)
.await
{
warn!(error = %err, "MITM tunnel error");
warn!("MITM tunnel error: {err}");
}
return Ok(());
}
let forwarder = Forwarder::ctx();
if let Err(err) = forwarder.serve(ctx, upgraded).await {
warn!(error = %err, "tunnel error");
warn!("tunnel error: {err}");
}
Ok(())
}
@@ -216,7 +208,7 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
Ok(allowed) => allowed,
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
error!("failed to evaluate method policy: {err}");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
@@ -225,21 +217,19 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
.headers()
.get("x-unix-socket")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string())
.map(ToString::to_string)
{
if !method_allowed {
let client = client.as_deref().unwrap_or_default();
let method = req.method();
warn!(
client = %client.as_deref().unwrap_or_default(),
method = %req.method(),
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"unix socket blocked by method policy"
"unix socket blocked by method policy (client={client}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Ok(json_blocked("unix-socket", "method_not_allowed"));
}
if !cfg!(target_os = "macos") {
warn!(path = %socket_path, "unix socket proxy unsupported on this platform");
warn!("unix socket proxy unsupported on this platform (path={socket_path})");
return Ok(text_response(
StatusCode::NOT_IMPLEMENTED,
"unix sockets unsupported",
@@ -248,15 +238,12 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
match app_state.is_unix_socket_allowed(&socket_path).await {
Ok(true) => {
info!(
client = %client.as_deref().unwrap_or_default(),
path = %socket_path,
"unix socket allowed"
);
let client = client.as_deref().unwrap_or_default();
info!("unix socket allowed (client={client}, path={socket_path})");
match proxy_via_unix_socket(ctx, req, &socket_path).await {
Ok(resp) => return Ok(resp),
Err(err) => {
warn!(error = %err, "unix socket proxy failed");
warn!("unix socket proxy failed: {err}");
return Ok(text_response(
StatusCode::BAD_GATEWAY,
"unix socket proxy failed",
@@ -265,15 +252,12 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
}
}
Ok(false) => {
warn!(
client = %client.as_deref().unwrap_or_default(),
path = %socket_path,
"unix socket blocked"
);
let client = client.as_deref().unwrap_or_default();
warn!("unix socket blocked (client={client}, path={socket_path})");
return Ok(json_blocked("unix-socket", "not_allowed"));
}
Err(err) => {
warn!(error = %err, "unix socket check failed");
warn!("unix socket check failed: {err}");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
@@ -285,7 +269,7 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
{
Ok(authority) => authority,
Err(err) => {
warn!(error = %err, "missing host");
warn!("missing host: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
}
};
@@ -303,17 +287,13 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
"http".to_string(),
))
.await;
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
reason = %reason,
"request blocked"
);
let client = client.as_deref().unwrap_or_default();
warn!("request blocked (client={client}, host={host}, reason={reason})");
return Ok(json_blocked(&host, &reason));
}
Ok((false, _)) => {}
Err(err) => {
error!(error = %err, "failed to evaluate host");
error!("failed to evaluate host for {host}: {err}");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
@@ -329,29 +309,23 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
"http".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
let method = req.method();
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
method = %req.method(),
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"request blocked by method policy"
"request blocked by method policy (client={client}, host={host}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Ok(json_blocked(&host, "method_not_allowed"));
}
info!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
method = %req.method(),
"request allowed"
);
let client = client.as_deref().unwrap_or_default();
let method = req.method();
info!("request allowed (client={client}, host={host}, method={method})");
let client = EasyHttpWebClient::default();
match client.serve(ctx, req).await {
Ok(resp) => Ok(resp),
Err(err) => {
warn!(error = %err, "upstream request failed");
warn!("upstream request failed: {err}");
Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
}
}
@@ -377,7 +351,7 @@ async fn proxy_via_unix_socket(
let path = parts
.uri
.path_and_query()
.map(|pq| pq.as_str())
.map(PathAndQuery::as_str)
.unwrap_or("/");
parts.uri = path
.parse()

View File

@@ -41,8 +41,7 @@ pub async fn run_main(args: Args) -> Result<()> {
warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
}
let cfg_path = config::default_codex_config_path()?;
let state = Arc::new(AppState::new(cfg_path).await?);
let state = Arc::new(AppState::new().await?);
let runtime = config::resolve_runtime(&state.current_cfg().await?);
let http_addr: SocketAddr = runtime.http_addr;

View File

@@ -1,4 +1,3 @@
#[cfg(feature = "mitm")]
mod imp {
use crate::config::MitmConfig;
use crate::config::NetworkMode;
@@ -22,6 +21,7 @@ mod imp {
use rama::http::Response;
use rama::http::StatusCode;
use rama::http::Uri;
use rama::http::dep::http::uri::PathAndQuery;
use rama::http::header::HOST;
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
@@ -44,20 +44,19 @@ mod imp {
use tracing::info;
use tracing::warn;
use rcgen::BasicConstraints;
use rcgen::Certificate;
use rcgen::CertificateParams;
use rcgen::DistinguishedName;
use rcgen::DnType;
use rcgen::ExtendedKeyUsagePurpose;
use rcgen::IsCa;
use rcgen::KeyPair;
use rcgen::KeyUsagePurpose;
use rcgen::SanType;
use rcgen_rama::BasicConstraints;
use rcgen_rama::CertificateParams;
use rcgen_rama::DistinguishedName;
use rcgen_rama::DnType;
use rcgen_rama::ExtendedKeyUsagePurpose;
use rcgen_rama::IsCa;
use rcgen_rama::Issuer;
use rcgen_rama::KeyPair;
use rcgen_rama::KeyUsagePurpose;
use rcgen_rama::SanType;
pub struct MitmState {
ca_key: KeyPair,
ca_cert: Certificate,
issuer: Issuer<'static, KeyPair>,
upstream: rama::service::BoxService<Arc<AppState>, Request, Response, OpaqueError>,
inspect: bool,
max_body_bytes: usize,
@@ -67,11 +66,8 @@ mod imp {
pub fn new(cfg: &MitmConfig) -> Result<Self> {
let (ca_cert_pem, ca_key_pem) = load_or_create_ca(cfg)?;
let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
let ca_params = CertificateParams::from_ca_cert_pem(&ca_cert_pem)
let issuer: Issuer<'static, KeyPair> = Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
.context("failed to parse CA cert")?;
let ca_cert = ca_params
.self_signed(&ca_key)
.context("failed to reconstruct CA cert")?;
let tls_config = rama::tls::rustls::client::TlsConnectorData::new_http_auto()
.context("create upstream TLS config")?;
@@ -84,8 +80,7 @@ mod imp {
.boxed();
Ok(Self {
ca_key,
ca_cert,
issuer,
upstream,
inspect: cfg.inspect,
max_body_bytes: cfg.max_body_bytes,
@@ -93,8 +88,7 @@ mod imp {
}
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
let (cert_pem, key_pem) =
issue_host_certificate_pem(host, &self.ca_cert, &self.ca_key)?;
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
let cert_chain = pemfile::certs(&mut BufReader::new(cert_pem.as_bytes()))
.collect::<Result<Vec<_>, _>>()
.context("failed to parse host cert PEM")?;
@@ -160,7 +154,7 @@ mod imp {
let response = match forward_request(ctx, req).await {
Ok(resp) => resp,
Err(err) => {
warn!(error = %err, "MITM upstream request failed");
warn!("MITM upstream request failed: {err}");
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
}
};
@@ -201,11 +195,7 @@ mod imp {
if let Some(request_host) = extract_request_host(&req) {
let normalized = normalize_host(&request_host);
if !normalized.is_empty() && normalized != target_host {
warn!(
target = %target_host,
request_host = %normalized,
"MITM host mismatch"
);
warn!("MITM host mismatch (target={target_host}, request_host={normalized})");
return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
}
}
@@ -223,12 +213,7 @@ mod imp {
))
.await;
warn!(
host = %target_host,
method = %method,
path = %path,
mode = ?mode,
allowed_methods = "GET, HEAD, OPTIONS",
"MITM blocked by method policy"
"MITM blocked by method policy (host={target_host}, method={method}, path={path}, mode={mode:?}, allowed_methods=GET, HEAD, OPTIONS)"
);
return Ok(blocked_text("method_not_allowed"));
}
@@ -355,27 +340,23 @@ mod imp {
impl BodyLoggable for RequestLogContext {
fn log(self, len: usize, truncated: bool) {
let host = self.host;
let method = self.method;
let path = self.path;
info!(
host = %self.host,
method = %self.method,
path = %self.path,
body_len = len,
truncated = truncated,
"MITM inspected request body"
"MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
);
}
}
impl BodyLoggable for ResponseLogContext {
fn log(self, len: usize, truncated: bool) {
let host = self.host;
let method = self.method;
let path = self.path;
let status = self.status;
info!(
host = %self.host,
method = %self.method,
path = %self.path,
status = %self.status,
body_len = len,
truncated = truncated,
"MITM inspected response body"
"MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
);
}
}
@@ -384,7 +365,7 @@ mod imp {
req.headers()
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string())
.map(ToString::to_string)
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
}
@@ -410,15 +391,14 @@ mod imp {
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(|pq| pq.as_str())
.map(PathAndQuery::as_str)
.unwrap_or("/")
.to_string()
}
fn issue_host_certificate_pem(
host: &str,
ca_cert: &Certificate,
ca_key: &KeyPair,
issuer: &Issuer<'_, KeyPair>,
) -> Result<(String, String)> {
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
let mut params = CertificateParams::new(Vec::new())
@@ -436,10 +416,10 @@ mod imp {
KeyUsagePurpose::KeyEncipherment,
];
let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
.map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
let cert = params
.signed_by(&key_pair, ca_cert, ca_key)
.signed_by(&key_pair, issuer)
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
Ok((cert.pem(), key_pair.serialize_pem()))
@@ -472,11 +452,9 @@ mod imp {
let (cert_pem, key_pem) = generate_ca()?;
write_private_file(cert_path, cert_pem.as_bytes(), 0o644)?;
write_private_file(key_path, key_pem.as_bytes(), 0o600)?;
info!(
cert_path = %cert_path.display(),
key_path = %key_path.display(),
"generated MITM CA"
);
let cert_path = cert_path.display();
let key_path = key_path.display();
info!("generated MITM CA (cert_path={cert_path}, key_path={key_path})");
Ok((cert_pem, key_pem))
}
@@ -492,7 +470,7 @@ mod imp {
dn.push(DnType::CommonName, "network_proxy MITM CA");
params.distinguished_name = dn;
let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
.map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
let cert = params
.self_signed(&key_pair)
@@ -557,45 +535,4 @@ mod imp {
}
}
#[cfg(not(feature = "mitm"))]
mod imp {
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use crate::state::AppState;
use anyhow::Result;
use anyhow::anyhow;
use rama::Context as RamaContext;
use rama::http::layer::upgrade::Upgraded;
use std::sync::Arc;
#[derive(Debug)]
pub struct MitmState;
#[allow(dead_code)]
impl MitmState {
pub fn new(_cfg: &MitmConfig) -> Result<Self> {
Err(anyhow!("MITM feature disabled at build time"))
}
pub fn inspect_enabled(&self) -> bool {
false
}
pub fn max_body_bytes(&self) -> usize {
0
}
}
pub async fn mitm_tunnel(
_ctx: RamaContext<Arc<AppState>>,
_upgraded: Upgraded,
_host: &str,
_port: u16,
_mode: NetworkMode,
_state: Arc<MitmState>,
) -> Result<()> {
Err(anyhow!("MITM feature disabled at build time"))
}
}
pub use imp::*;

View File

@@ -21,10 +21,10 @@ pub fn is_loopback_host(host: &str) -> bool {
pub fn normalize_host(host: &str) -> String {
let host = host.trim();
if host.starts_with('[') {
if let Some(end) = host.find(']') {
return host[1..end].to_ascii_lowercase();
}
if host.starts_with('[')
&& let Some(end) = host.find(']')
{
return host[1..end].to_ascii_lowercase();
}
host.split(':').next().unwrap_or("").to_ascii_lowercase()
}

View File

@@ -1,9 +1,9 @@
use hyper::Body;
use hyper::Response;
use hyper::StatusCode;
use rama::http::Body;
use rama::http::Response;
use rama::http::StatusCode;
use serde::Serialize;
pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
pub fn text_response(status: StatusCode, body: &str) -> Response {
Response::builder()
.status(status)
.header("content-type", "text/plain")
@@ -11,7 +11,7 @@ pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
pub fn json_response<T: Serialize>(value: &T) -> Response<Body> {
pub fn json_response<T: Serialize>(value: &T) -> Response {
let body = match serde_json::to_string(value) {
Ok(body) => body,
Err(_) => "{}".to_string(),

View File

@@ -26,24 +26,21 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
.await
.map_err(|err| anyhow!("bind SOCKS5 proxy: {err}"))?;
info!(addr = %addr, "SOCKS5 proxy listening");
info!("SOCKS5 proxy listening on {addr}");
match state.network_mode().await {
Ok(NetworkMode::Limited) => {
info!(
mode = "limited",
"SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5"
);
info!("SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5");
}
Ok(NetworkMode::Full) => {}
Err(err) => {
warn!(error = %err, "failed to read network mode");
warn!("failed to read network mode: {err}");
}
}
let tcp_connector = TcpConnector::default();
let policy_tcp_connector =
service_fn(move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
let policy_tcp_connector = service_fn(
move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
let tcp_connector = tcp_connector.clone();
async move {
let app_state = ctx.state().clone();
@@ -66,12 +63,9 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"SOCKS blocked by method policy"
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
@@ -79,8 +73,8 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
error!("failed to evaluate method policy: {err}");
return Err(io::Error::other("proxy error").into());
}
}
@@ -96,33 +90,26 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
"socks5".to_string(),
))
.await;
warn!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
reason = %reason,
"SOCKS blocked"
);
let client = client.as_deref().unwrap_or_default();
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok((false, _)) => {
info!(
client = %client.as_deref().unwrap_or_default(),
host = %host,
port = port,
"SOCKS allowed"
);
let client = client.as_deref().unwrap_or_default();
info!("SOCKS allowed (client={client}, host={host}, port={port})");
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
error!("failed to evaluate host: {err}");
return Err(io::Error::other("proxy error").into());
}
}
tcp_connector.serve(ctx, req).await
}
});
},
);
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
let socks_acceptor = Socks5Acceptor::new().with_connector(socks_connector);

View File

@@ -6,10 +6,15 @@ use crate::policy::is_loopback_host;
use crate::policy::method_allowed;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use codex_app_server_protocol::ConfigLayerSource;
use codex_core::config::CONFIG_TOML_FILE;
use codex_core::config::ConfigBuilder;
use codex_core::config::Constrained;
use codex_core::config::ConstraintError;
use globset::GlobBuilder;
use globset::GlobSet;
use globset::GlobSetBuilder;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashSet;
use std::collections::VecDeque;
@@ -17,7 +22,7 @@ use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use time::OffsetDateTime;
use tokio::sync::RwLock;
use tracing::info;
use tracing::warn;
@@ -58,7 +63,7 @@ impl BlockedRequest {
#[derive(Clone)]
struct ConfigState {
cfg: Config,
config: Config,
mtime: Option<SystemTime>,
allow_set: GlobSet,
deny_set: GlobSet,
@@ -73,8 +78,8 @@ pub struct AppState {
}
impl AppState {
pub async fn new(cfg_path: PathBuf) -> Result<Self> {
let cfg_state = build_config_state(cfg_path)?;
pub async fn new() -> Result<Self> {
let cfg_state = build_config_state().await?;
Ok(Self {
state: Arc::new(RwLock::new(cfg_state)),
})
@@ -83,33 +88,34 @@ impl AppState {
pub async fn current_cfg(&self) -> Result<Config> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.cfg.clone())
Ok(guard.config.clone())
}
pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok((
guard.cfg.network_proxy.policy.allowed_domains.clone(),
guard.cfg.network_proxy.policy.denied_domains.clone(),
guard.config.network_proxy.policy.allowed_domains.clone(),
guard.config.network_proxy.policy.denied_domains.clone(),
))
}
pub async fn force_reload(&self) -> Result<()> {
let mut guard = self.state.write().await;
let previous_cfg = guard.cfg.clone();
let previous_cfg = guard.config.clone();
let blocked = guard.blocked.clone();
let cfg_path = guard.cfg_path.clone();
match build_config_state(cfg_path.clone()) {
match build_config_state().await {
Ok(mut new_state) => {
log_policy_changes(&previous_cfg, &new_state.cfg);
log_policy_changes(&previous_cfg, &new_state.config);
new_state.blocked = blocked;
*guard = new_state;
info!(path = %cfg_path.display(), "reloaded config");
let path = guard.cfg_path.display();
info!("reloaded config from {path}");
Ok(())
}
Err(err) => {
warn!(error = %err, path = %cfg_path.display(), "failed to reload config; keeping previous config");
let path = guard.cfg_path.display();
warn!("failed to reload config from {path}: {err}; keeping previous config");
Err(err)
}
}
@@ -123,12 +129,12 @@ impl AppState {
}
let is_loopback = is_loopback_host(host);
if is_loopback
&& !guard.cfg.network_proxy.policy.allow_local_binding
&& !guard.config.network_proxy.policy.allow_local_binding
&& !guard.allow_set.is_match(host)
{
return Ok((true, "not_allowed_local".to_string()));
}
if guard.cfg.network_proxy.policy.allowed_domains.is_empty()
if guard.config.network_proxy.policy.allowed_domains.is_empty()
|| !guard.allow_set.is_match(host)
{
return Ok((true, "not_allowed".to_string()));
@@ -157,7 +163,7 @@ impl AppState {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard
.cfg
.config
.network_proxy
.policy
.allow_unix_sockets
@@ -168,20 +174,20 @@ impl AppState {
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(method_allowed(guard.cfg.network_proxy.mode, method))
Ok(method_allowed(guard.config.network_proxy.mode, method))
}
pub async fn network_mode(&self) -> Result<NetworkMode> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.cfg.network_proxy.mode)
Ok(guard.config.network_proxy.mode)
}
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
self.reload_if_needed().await?;
let mut guard = self.state.write().await;
guard.cfg.network_proxy.mode = mode;
info!(mode = ?mode, "updated network mode");
guard.config.network_proxy.mode = mode;
info!("updated network mode to {mode:?}");
Ok(())
}
@@ -195,7 +201,9 @@ impl AppState {
let needs_reload = {
let guard = self.state.read().await;
if !guard.cfg_path.exists() {
true
// If the config file is missing, only reload when it *used to* exist (mtime set).
// This avoids forcing a reload on every request when running with the default config.
guard.mtime.is_some()
} else {
let metadata = std::fs::metadata(&guard.cfg_path).ok();
match (metadata.and_then(|m| m.modified().ok()), guard.mtime) {
@@ -214,28 +222,32 @@ impl AppState {
}
}
fn build_config_state(cfg_path: PathBuf) -> Result<ConfigState> {
let mut cfg = if cfg_path.exists() {
load_config_from_path(&cfg_path).with_context(|| {
format!(
"failed to load config from {}",
cfg_path.as_path().display()
)
})?
} else {
Config::default()
};
resolve_mitm_paths(&mut cfg, &cfg_path);
async fn build_config_state() -> Result<ConfigState> {
let codex_cfg = ConfigBuilder::default()
.build()
.await
.context("failed to load Codex config")?;
let cfg_path = codex_cfg.codex_home.join(CONFIG_TOML_FILE);
let merged_toml = codex_cfg.config_layer_stack.effective_config();
let mut config: Config = merged_toml
.try_into()
.context("failed to deserialize network proxy config")?;
enforce_trusted_constraints(&codex_cfg.config_layer_stack, &config)?;
resolve_mitm_paths(&mut config, &cfg_path);
let mtime = cfg_path.metadata().and_then(|m| m.modified()).ok();
let deny_set = compile_globset(&cfg.network_proxy.policy.denied_domains)?;
let allow_set = compile_globset(&cfg.network_proxy.policy.allowed_domains)?;
let mitm = if cfg.network_proxy.mitm.enabled {
build_mitm_state(&cfg.network_proxy.mitm)?
let deny_set = compile_globset(&config.network_proxy.policy.denied_domains)?;
let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains)?;
let mitm = if config.network_proxy.mitm.enabled {
build_mitm_state(&config.network_proxy.mitm)?
} else {
None
};
Ok(ConfigState {
cfg,
config,
mtime,
allow_set,
deny_set,
@@ -245,25 +257,248 @@ fn build_config_state(cfg_path: PathBuf) -> Result<ConfigState> {
})
}
fn resolve_mitm_paths(cfg: &mut Config, cfg_path: &Path) {
fn resolve_mitm_paths(config: &mut Config, cfg_path: &Path) {
let base = cfg_path.parent().unwrap_or_else(|| Path::new("."));
if cfg.network_proxy.mitm.ca_cert_path.is_relative() {
cfg.network_proxy.mitm.ca_cert_path = base.join(&cfg.network_proxy.mitm.ca_cert_path);
if config.network_proxy.mitm.ca_cert_path.is_relative() {
config.network_proxy.mitm.ca_cert_path = base.join(&config.network_proxy.mitm.ca_cert_path);
}
if cfg.network_proxy.mitm.ca_key_path.is_relative() {
cfg.network_proxy.mitm.ca_key_path = base.join(&cfg.network_proxy.mitm.ca_key_path);
if config.network_proxy.mitm.ca_key_path.is_relative() {
config.network_proxy.mitm.ca_key_path = base.join(&config.network_proxy.mitm.ca_key_path);
}
}
fn build_mitm_state(_cfg: &MitmConfig) -> Result<Option<Arc<MitmState>>> {
#[cfg(feature = "mitm")]
fn build_mitm_state(config: &MitmConfig) -> Result<Option<Arc<MitmState>>> {
Ok(Some(Arc::new(MitmState::new(config)?)))
}
#[derive(Debug, Default, Deserialize)]
struct PartialConfig {
#[serde(default)]
network_proxy: PartialNetworkProxyConfig,
}
#[derive(Debug, Default, Deserialize)]
struct PartialNetworkProxyConfig {
enabled: Option<bool>,
mode: Option<NetworkMode>,
#[serde(default)]
policy: PartialNetworkPolicy,
}
#[derive(Debug, Default, Deserialize)]
struct PartialNetworkPolicy {
#[serde(default)]
allowed_domains: Option<Vec<String>>,
#[serde(default)]
denied_domains: Option<Vec<String>>,
#[serde(default)]
allow_unix_sockets: Option<Vec<String>>,
#[serde(default)]
allow_local_binding: Option<bool>,
}
#[derive(Debug, Default)]
struct NetworkProxyConstraints {
enabled: Option<bool>,
mode: Option<NetworkMode>,
allowed_domains: Option<Vec<String>>,
denied_domains: Option<Vec<String>>,
allow_unix_sockets: Option<Vec<String>>,
allow_local_binding: Option<bool>,
}
fn enforce_trusted_constraints(
layers: &codex_core::config_loader::ConfigLayerStack,
config: &Config,
) -> Result<()> {
let constraints = network_proxy_constraints_from_trusted_layers(layers)?;
validate_policy_against_constraints(config, &constraints)
.context("network proxy constraints")?;
Ok(())
}
fn network_proxy_constraints_from_trusted_layers(
layers: &codex_core::config_loader::ConfigLayerStack,
) -> Result<NetworkProxyConstraints> {
let mut constraints = NetworkProxyConstraints::default();
for layer in layers
.get_layers(codex_core::config_loader::ConfigLayerStackOrdering::LowestPrecedenceFirst)
{
return Ok(Some(Arc::new(MitmState::new(_cfg)?)));
if is_user_controlled_layer(&layer.name) {
continue;
}
let partial: PartialConfig = layer
.config
.clone()
.try_into()
.context("failed to deserialize trusted config layer")?;
if let Some(enabled) = partial.network_proxy.enabled {
constraints.enabled = Some(enabled);
}
if let Some(mode) = partial.network_proxy.mode {
constraints.mode = Some(mode);
}
if let Some(allowed_domains) = partial.network_proxy.policy.allowed_domains {
constraints.allowed_domains = Some(allowed_domains);
}
if let Some(denied_domains) = partial.network_proxy.policy.denied_domains {
constraints.denied_domains = Some(denied_domains);
}
if let Some(allow_unix_sockets) = partial.network_proxy.policy.allow_unix_sockets {
constraints.allow_unix_sockets = Some(allow_unix_sockets);
}
if let Some(allow_local_binding) = partial.network_proxy.policy.allow_local_binding {
constraints.allow_local_binding = Some(allow_local_binding);
}
}
#[cfg(not(feature = "mitm"))]
{
warn!("MITM enabled in config but binary built without mitm feature");
Ok(None)
Ok(constraints)
}
fn is_user_controlled_layer(layer: &ConfigLayerSource) -> bool {
matches!(
layer,
ConfigLayerSource::User { .. }
| ConfigLayerSource::Project { .. }
| ConfigLayerSource::SessionFlags
)
}
fn validate_policy_against_constraints(
config: &Config,
constraints: &NetworkProxyConstraints,
) -> std::result::Result<(), ConstraintError> {
let enabled = config.network_proxy.enabled;
if let Some(max_enabled) = constraints.enabled {
let _ = Constrained::new(enabled, move |candidate| {
if *candidate && !max_enabled {
Err(ConstraintError::invalid_value(
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
})?;
}
if let Some(max_mode) = constraints.mode {
let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
Err(ConstraintError::invalid_value(
format!("{candidate:?}"),
format!("{max_mode:?} or more restrictive"),
))
} else {
Ok(())
}
})?;
}
if let Some(allow_local_binding) = constraints.allow_local_binding {
let _ = Constrained::new(
config.network_proxy.policy.allow_local_binding,
move |candidate| {
if *candidate && !allow_local_binding {
Err(ConstraintError::invalid_value(
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
},
)?;
}
if let Some(allowed_domains) = &constraints.allowed_domains {
let allowed_set: HashSet<String> = allowed_domains
.iter()
.map(|s| s.to_ascii_lowercase())
.collect();
let _ = Constrained::new(
config.network_proxy.policy.allowed_domains.clone(),
move |candidate| {
let mut invalid = Vec::new();
for entry in candidate {
if !allowed_set.contains(&entry.to_ascii_lowercase()) {
invalid.push(entry.clone());
}
}
if invalid.is_empty() {
Ok(())
} else {
Err(ConstraintError::invalid_value(
format!("{invalid:?}"),
"subset of managed allowed_domains",
))
}
},
)?;
}
if let Some(denied_domains) = &constraints.denied_domains {
let required_set: HashSet<String> = denied_domains
.iter()
.map(|s| s.to_ascii_lowercase())
.collect();
let _ = Constrained::new(
config.network_proxy.policy.denied_domains.clone(),
move |candidate| {
let candidate_set: HashSet<String> =
candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
let missing: Vec<String> = required_set
.iter()
.filter(|entry| !candidate_set.contains(*entry))
.cloned()
.collect();
if missing.is_empty() {
Ok(())
} else {
Err(ConstraintError::invalid_value(
"missing managed denied_domains entries",
format!("{missing:?}"),
))
}
},
)?;
}
if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
let allowed_set: HashSet<String> = allow_unix_sockets
.iter()
.map(|s| s.to_ascii_lowercase())
.collect();
let _ = Constrained::new(
config.network_proxy.policy.allow_unix_sockets.clone(),
move |candidate| {
let mut invalid = Vec::new();
for entry in candidate {
if !allowed_set.contains(&entry.to_ascii_lowercase()) {
invalid.push(entry.clone());
}
}
if invalid.is_empty() {
Ok(())
} else {
Err(ConstraintError::invalid_value(
format!("{invalid:?}"),
"subset of managed allow_unix_sockets",
))
}
},
)?;
}
Ok(())
}
fn network_mode_rank(mode: NetworkMode) -> u8 {
match mode {
NetworkMode::Limited => 0,
NetworkMode::Full => 1,
}
}
@@ -317,7 +552,7 @@ fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]
for entry in next {
let key = entry.to_ascii_lowercase();
if seen_next.insert(key.clone()) && !previous_set.contains(&key) {
info!(list = list_name, entry = %entry, "config entry added");
info!("config entry added to {list_name}: {entry}");
}
}
@@ -325,20 +560,11 @@ fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]
for entry in previous {
let key = entry.to_ascii_lowercase();
if seen_previous.insert(key.clone()) && !next_set.contains(&key) {
info!(list = list_name, entry = %entry, "config entry removed");
info!("config entry removed from {list_name}: {entry}");
}
}
}
fn unix_timestamp() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_secs() as i64)
.unwrap_or(0)
}
fn load_config_from_path(path: &Path) -> Result<Config> {
let raw = std::fs::read_to_string(path)
.with_context(|| format!("unable to read config file {}", path.display()))?;
toml::from_str(&raw).map_err(|err| anyhow!("unable to parse config: {err}"))
OffsetDateTime::now_utc().unix_timestamp()
}