mirror of
https://github.com/openai/codex.git
synced 2026-04-29 17:06:51 +00:00
address feedback
This commit is contained in:
@@ -3,97 +3,119 @@ use crate::responses::json_response;
|
||||
use crate::responses::text_response;
|
||||
use crate::state::AppState;
|
||||
use anyhow::Result;
|
||||
use hyper::Body;
|
||||
use hyper::Method;
|
||||
use hyper::Request;
|
||||
use hyper::Response;
|
||||
use hyper::Server;
|
||||
use hyper::StatusCode;
|
||||
use hyper::body::to_bytes;
|
||||
use hyper::service::make_service_fn;
|
||||
use hyper::service::service_fn;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::http::Body;
|
||||
use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::server::HttpServer;
|
||||
use rama::service::service_fn;
|
||||
use rama::tcp::server::TcpListener;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
|
||||
type ContextState = Arc<AppState>;
|
||||
type AdminContext = RamaContext<ContextState>;
|
||||
|
||||
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
let make_svc = make_service_fn(move |_conn: &hyper::server::conn::AddrStream| {
|
||||
let state = state.clone();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
handle_admin_request(req, state.clone())
|
||||
}))
|
||||
}
|
||||
});
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
info!(addr = %addr, "admin API listening");
|
||||
server.await?;
|
||||
let listener = TcpListener::build_with_state(state)
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!("bind admin API: {err}"))?;
|
||||
|
||||
let server =
|
||||
HttpServer::auto(rama::rt::Executor::new()).service(service_fn(handle_admin_request));
|
||||
info!("admin API listening on {addr}");
|
||||
listener.serve(server).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_admin_request(
|
||||
req: Request<Body>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
async fn handle_admin_request(ctx: AdminContext, req: Request) -> Result<Response, Infallible> {
|
||||
const MODE_BODY_LIMIT: usize = 8 * 1024;
|
||||
|
||||
let state = ctx.state().clone();
|
||||
let method = req.method().clone();
|
||||
let path = req.uri().path().to_string();
|
||||
let response = match (method, path.as_str()) {
|
||||
(Method::GET, "/health") => Response::new(Body::from("ok")),
|
||||
(Method::GET, "/config") => match state.current_cfg().await {
|
||||
let response = match (method.as_str(), path.as_str()) {
|
||||
("GET", "/health") => Response::new(Body::from("ok")),
|
||||
("GET", "/config") => match state.current_cfg().await {
|
||||
Ok(cfg) => json_response(&cfg),
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to load config");
|
||||
error!("failed to load config: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
|
||||
}
|
||||
},
|
||||
(Method::GET, "/patterns") => match state.current_patterns().await {
|
||||
Ok((allow, deny)) => json_response(&json!({"allowed": allow, "denied": deny})),
|
||||
("GET", "/patterns") => match state.current_patterns().await {
|
||||
Ok((allow, deny)) => json_response(&PatternsResponse {
|
||||
allowed: allow,
|
||||
denied: deny,
|
||||
}),
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to load patterns");
|
||||
error!("failed to load patterns: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
|
||||
}
|
||||
},
|
||||
(Method::GET, "/blocked") => match state.drain_blocked().await {
|
||||
Ok(blocked) => json_response(&json!({ "blocked": blocked })),
|
||||
("GET", "/blocked") => match state.drain_blocked().await {
|
||||
Ok(blocked) => json_response(&BlockedResponse { blocked }),
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to read blocked queue");
|
||||
error!("failed to read blocked queue: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
|
||||
}
|
||||
},
|
||||
(Method::POST, "/mode") => {
|
||||
let body = match to_bytes(req.into_body()).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to read mode body");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
|
||||
("POST", "/mode") => {
|
||||
let mut body = req.into_body();
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
loop {
|
||||
let chunk = match body.chunk().await {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
error!("failed to read mode body: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
|
||||
}
|
||||
};
|
||||
let Some(chunk) = chunk else {
|
||||
break;
|
||||
};
|
||||
|
||||
if buf.len().saturating_add(chunk.len()) > MODE_BODY_LIMIT {
|
||||
return Ok(text_response(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"body too large",
|
||||
));
|
||||
}
|
||||
};
|
||||
if body.is_empty() {
|
||||
buf.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if buf.is_empty() {
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
|
||||
}
|
||||
let update: ModeUpdate = match serde_json::from_slice(&body) {
|
||||
let update: ModeUpdate = match serde_json::from_slice(&buf) {
|
||||
Ok(update) => update,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to parse mode update");
|
||||
error!("failed to parse mode update: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
|
||||
}
|
||||
};
|
||||
match state.set_network_mode(update.mode).await {
|
||||
Ok(()) => json_response(&json!({"status": "ok", "mode": update.mode})),
|
||||
Ok(()) => json_response(&ModeUpdateResponse {
|
||||
status: "ok",
|
||||
mode: update.mode,
|
||||
}),
|
||||
Err(err) => {
|
||||
error!(error = %err, "mode update failed");
|
||||
error!("mode update failed: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
(Method::POST, "/reload") => match state.force_reload().await {
|
||||
Ok(()) => json_response(&json!({"status": "reloaded"})),
|
||||
("POST", "/reload") => match state.force_reload().await {
|
||||
Ok(()) => json_response(&ReloadResponse { status: "reloaded" }),
|
||||
Err(err) => {
|
||||
error!(error = %err, "reload failed");
|
||||
error!("reload failed: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
|
||||
}
|
||||
},
|
||||
@@ -106,3 +128,25 @@ async fn handle_admin_request(
|
||||
struct ModeUpdate {
|
||||
mode: NetworkMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct PatternsResponse {
|
||||
allowed: Vec<String>,
|
||||
denied: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BlockedResponse<T> {
|
||||
blocked: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ModeUpdateResponse {
|
||||
status: &'static str,
|
||||
mode: NetworkMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ReloadResponse {
|
||||
status: &'static str,
|
||||
}
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_core::config::default_config_path;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::net::IpAddr;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Config {
|
||||
#[serde(default)]
|
||||
pub network_proxy: NetworkProxyConfig,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
network_proxy: NetworkProxyConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NetworkProxyConfig {
|
||||
#[serde(default)]
|
||||
@@ -50,42 +39,26 @@ impl Default for NetworkProxyConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct NetworkPolicy {
|
||||
#[serde(default, rename = "allowed_domains", alias = "allowedDomains")]
|
||||
#[serde(default)]
|
||||
pub allowed_domains: Vec<String>,
|
||||
#[serde(default, rename = "denied_domains", alias = "deniedDomains")]
|
||||
#[serde(default)]
|
||||
pub denied_domains: Vec<String>,
|
||||
#[serde(default, rename = "allow_unix_sockets", alias = "allowUnixSockets")]
|
||||
#[serde(default)]
|
||||
pub allow_unix_sockets: Vec<String>,
|
||||
#[serde(default, rename = "allow_local_binding", alias = "allowLocalBinding")]
|
||||
#[serde(default)]
|
||||
pub allow_local_binding: bool,
|
||||
}
|
||||
|
||||
impl Default for NetworkPolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed_domains: Vec::new(),
|
||||
denied_domains: Vec::new(),
|
||||
allow_unix_sockets: Vec::new(),
|
||||
allow_local_binding: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NetworkMode {
|
||||
Limited,
|
||||
#[default]
|
||||
Full,
|
||||
}
|
||||
|
||||
impl Default for NetworkMode {
|
||||
fn default() -> Self {
|
||||
NetworkMode::Full
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MitmConfig {
|
||||
#[serde(default)]
|
||||
@@ -138,10 +111,6 @@ pub struct RuntimeConfig {
|
||||
pub admin_addr: SocketAddr,
|
||||
}
|
||||
|
||||
pub fn default_codex_config_path() -> Result<PathBuf> {
|
||||
default_config_path().context("failed to resolve Codex config path")
|
||||
}
|
||||
|
||||
pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
|
||||
let http_addr = resolve_addr(&cfg.network_proxy.proxy_url, 3128);
|
||||
let admin_addr = resolve_addr(&cfg.network_proxy.admin_url, 8080);
|
||||
@@ -155,22 +124,30 @@ pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
|
||||
}
|
||||
|
||||
fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
|
||||
let (host, port) = parse_host_port(url, default_port);
|
||||
let host = if host.eq_ignore_ascii_case("localhost") {
|
||||
let addr_parts = parse_host_port(url, default_port);
|
||||
let host = if addr_parts.host.eq_ignore_ascii_case("localhost") {
|
||||
"127.0.0.1"
|
||||
} else {
|
||||
host
|
||||
addr_parts.host
|
||||
};
|
||||
match host.parse::<IpAddr>() {
|
||||
Ok(ip) => SocketAddr::new(ip, port),
|
||||
Err(_) => SocketAddr::from(([127, 0, 0, 1], port)),
|
||||
Ok(ip) => SocketAddr::new(ip, addr_parts.port),
|
||||
Err(_) => SocketAddr::from(([127, 0, 0, 1], addr_parts.port)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_host_port(url: &str, default_port: u16) -> (&str, u16) {
|
||||
struct SocketAddressParts<'a> {
|
||||
host: &'a str,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
|
||||
let trimmed = url.trim();
|
||||
if trimmed.is_empty() {
|
||||
return ("127.0.0.1", default_port);
|
||||
return SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: default_port,
|
||||
};
|
||||
}
|
||||
let without_scheme = trimmed
|
||||
.split_once("://")
|
||||
@@ -182,22 +159,25 @@ fn parse_host_port(url: &str, default_port: u16) -> (&str, u16) {
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(host_port);
|
||||
|
||||
if host_port.starts_with('[') {
|
||||
if let Some(end) = host_port.find(']') {
|
||||
let host = &host_port[1..end];
|
||||
let port = host_port[end + 1..]
|
||||
.strip_prefix(':')
|
||||
.and_then(|port| port.parse::<u16>().ok())
|
||||
.unwrap_or(default_port);
|
||||
return (host, port);
|
||||
}
|
||||
if host_port.starts_with('[')
|
||||
&& let Some(end) = host_port.find(']')
|
||||
{
|
||||
let host = &host_port[1..end];
|
||||
let port = host_port[end + 1..]
|
||||
.strip_prefix(':')
|
||||
.and_then(|port| port.parse::<u16>().ok())
|
||||
.unwrap_or(default_port);
|
||||
return SocketAddressParts { host, port };
|
||||
}
|
||||
|
||||
if let Some((host, port)) = host_port.rsplit_once(':') {
|
||||
if let Ok(port) = port.parse::<u16>() {
|
||||
return (host, port);
|
||||
}
|
||||
if let Some((host, port)) = host_port.rsplit_once(':')
|
||||
&& let Ok(port) = port.parse::<u16>()
|
||||
{
|
||||
return SocketAddressParts { host, port };
|
||||
}
|
||||
|
||||
(host_port, default_port)
|
||||
SocketAddressParts {
|
||||
host: host_port,
|
||||
port: default_port,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::Layer;
|
||||
use rama::Service;
|
||||
@@ -14,6 +13,7 @@ use rama::http::Request;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::client::EasyHttpWebClient;
|
||||
use rama::http::dep::http::uri::PathAndQuery;
|
||||
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama::http::layer::upgrade::UpgradeLayer;
|
||||
@@ -41,7 +41,8 @@ pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()
|
||||
let listener = TcpListener::build_with_state(state)
|
||||
.bind(addr)
|
||||
.await
|
||||
.map_err(|err| anyhow!("bind HTTP proxy: {err}"))?;
|
||||
.map_err(|err| anyhow::anyhow!(err))
|
||||
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
|
||||
|
||||
let http_service = HttpServer::auto(rama::rt::Executor::new()).service(
|
||||
(
|
||||
@@ -56,7 +57,7 @@ pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()
|
||||
.into_layer(service_fn(http_plain_proxy)),
|
||||
);
|
||||
|
||||
info!(addr = %addr, "HTTP proxy listening");
|
||||
info!("HTTP proxy listening on {addr}");
|
||||
|
||||
listener.serve(http_service).await;
|
||||
Ok(())
|
||||
@@ -72,7 +73,7 @@ async fn http_connect_accept(
|
||||
{
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "CONNECT missing authority");
|
||||
warn!("CONNECT missing authority: {err}");
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
|
||||
}
|
||||
};
|
||||
@@ -97,23 +98,16 @@ async fn http_connect_accept(
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
reason = %reason,
|
||||
"CONNECT blocked"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(blocked_text(&reason));
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
"CONNECT allowed"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("CONNECT allowed (client={client}, host={host})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
error!("failed to evaluate host for CONNECT {host}: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
@@ -121,7 +115,7 @@ async fn http_connect_accept(
|
||||
let mode = match app_state.network_mode().await {
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to read network mode");
|
||||
error!("failed to read network mode: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
@@ -129,7 +123,7 @@ async fn http_connect_accept(
|
||||
let mitm_state = match app_state.mitm_state().await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to load MITM state");
|
||||
error!("failed to load MITM state: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
@@ -145,12 +139,9 @@ async fn http_connect_accept(
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"CONNECT blocked; MITM required for read-only HTTPS in limited mode"
|
||||
"CONNECT blocked; MITM required for read-only HTTPS in limited mode (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(blocked_text("mitm_required"));
|
||||
}
|
||||
@@ -186,7 +177,8 @@ async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(),
|
||||
let host = normalize_host(&authority.host().to_string());
|
||||
|
||||
if let Some(mitm_state) = ctx.get::<Arc<mitm::MitmState>>().cloned() {
|
||||
info!(host = %host, port = authority.port(), mode = ?mode, "CONNECT MITM enabled");
|
||||
let port = authority.port();
|
||||
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
|
||||
if let Err(err) = mitm::mitm_tunnel(
|
||||
ctx,
|
||||
upgraded,
|
||||
@@ -197,14 +189,14 @@ async fn http_connect_proxy(ctx: ProxyContext, upgraded: Upgraded) -> Result<(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(error = %err, "MITM tunnel error");
|
||||
warn!("MITM tunnel error: {err}");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let forwarder = Forwarder::ctx();
|
||||
if let Err(err) = forwarder.serve(ctx, upgraded).await {
|
||||
warn!(error = %err, "tunnel error");
|
||||
warn!("tunnel error: {err}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -216,7 +208,7 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
|
||||
Ok(allowed) => allowed,
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate method policy");
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
@@ -225,21 +217,19 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
.headers()
|
||||
.get("x-unix-socket")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string())
|
||||
.map(ToString::to_string)
|
||||
{
|
||||
if !method_allowed {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
method = %req.method(),
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"unix socket blocked by method policy"
|
||||
"unix socket blocked by method policy (client={client}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(json_blocked("unix-socket", "method_not_allowed"));
|
||||
}
|
||||
|
||||
if !cfg!(target_os = "macos") {
|
||||
warn!(path = %socket_path, "unix socket proxy unsupported on this platform");
|
||||
warn!("unix socket proxy unsupported on this platform (path={socket_path})");
|
||||
return Ok(text_response(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"unix sockets unsupported",
|
||||
@@ -248,15 +238,12 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
|
||||
match app_state.is_unix_socket_allowed(&socket_path).await {
|
||||
Ok(true) => {
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
path = %socket_path,
|
||||
"unix socket allowed"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("unix socket allowed (client={client}, path={socket_path})");
|
||||
match proxy_via_unix_socket(ctx, req, &socket_path).await {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(err) => {
|
||||
warn!(error = %err, "unix socket proxy failed");
|
||||
warn!("unix socket proxy failed: {err}");
|
||||
return Ok(text_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"unix socket proxy failed",
|
||||
@@ -265,15 +252,12 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
}
|
||||
}
|
||||
Ok(false) => {
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
path = %socket_path,
|
||||
"unix socket blocked"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("unix socket blocked (client={client}, path={socket_path})");
|
||||
return Ok(json_blocked("unix-socket", "not_allowed"));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "unix socket check failed");
|
||||
warn!("unix socket check failed: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
@@ -285,7 +269,7 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
{
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "missing host");
|
||||
warn!("missing host: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
|
||||
}
|
||||
};
|
||||
@@ -303,17 +287,13 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
reason = %reason,
|
||||
"request blocked"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("request blocked (client={client}, host={host}, reason={reason})");
|
||||
return Ok(json_blocked(&host, &reason));
|
||||
}
|
||||
Ok((false, _)) => {}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
error!("failed to evaluate host for {host}: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
@@ -329,29 +309,23 @@ async fn http_plain_proxy(mut ctx: ProxyContext, req: Request) -> Result<Respons
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
method = %req.method(),
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"request blocked by method policy"
|
||||
"request blocked by method policy (client={client}, host={host}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(json_blocked(&host, "method_not_allowed"));
|
||||
}
|
||||
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
method = %req.method(),
|
||||
"request allowed"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
info!("request allowed (client={client}, host={host}, method={method})");
|
||||
|
||||
let client = EasyHttpWebClient::default();
|
||||
match client.serve(ctx, req).await {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(err) => {
|
||||
warn!(error = %err, "upstream request failed");
|
||||
warn!("upstream request failed: {err}");
|
||||
Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
|
||||
}
|
||||
}
|
||||
@@ -377,7 +351,7 @@ async fn proxy_via_unix_socket(
|
||||
let path = parts
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.map(PathAndQuery::as_str)
|
||||
.unwrap_or("/");
|
||||
parts.uri = path
|
||||
.parse()
|
||||
|
||||
@@ -41,8 +41,7 @@ pub async fn run_main(args: Args) -> Result<()> {
|
||||
warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
|
||||
}
|
||||
|
||||
let cfg_path = config::default_codex_config_path()?;
|
||||
let state = Arc::new(AppState::new(cfg_path).await?);
|
||||
let state = Arc::new(AppState::new().await?);
|
||||
let runtime = config::resolve_runtime(&state.current_cfg().await?);
|
||||
|
||||
let http_addr: SocketAddr = runtime.http_addr;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#[cfg(feature = "mitm")]
|
||||
mod imp {
|
||||
use crate::config::MitmConfig;
|
||||
use crate::config::NetworkMode;
|
||||
@@ -22,6 +21,7 @@ mod imp {
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use rama::http::Uri;
|
||||
use rama::http::dep::http::uri::PathAndQuery;
|
||||
use rama::http::header::HOST;
|
||||
use rama::http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama::http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
@@ -44,20 +44,19 @@ mod imp {
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use rcgen::BasicConstraints;
|
||||
use rcgen::Certificate;
|
||||
use rcgen::CertificateParams;
|
||||
use rcgen::DistinguishedName;
|
||||
use rcgen::DnType;
|
||||
use rcgen::ExtendedKeyUsagePurpose;
|
||||
use rcgen::IsCa;
|
||||
use rcgen::KeyPair;
|
||||
use rcgen::KeyUsagePurpose;
|
||||
use rcgen::SanType;
|
||||
use rcgen_rama::BasicConstraints;
|
||||
use rcgen_rama::CertificateParams;
|
||||
use rcgen_rama::DistinguishedName;
|
||||
use rcgen_rama::DnType;
|
||||
use rcgen_rama::ExtendedKeyUsagePurpose;
|
||||
use rcgen_rama::IsCa;
|
||||
use rcgen_rama::Issuer;
|
||||
use rcgen_rama::KeyPair;
|
||||
use rcgen_rama::KeyUsagePurpose;
|
||||
use rcgen_rama::SanType;
|
||||
|
||||
pub struct MitmState {
|
||||
ca_key: KeyPair,
|
||||
ca_cert: Certificate,
|
||||
issuer: Issuer<'static, KeyPair>,
|
||||
upstream: rama::service::BoxService<Arc<AppState>, Request, Response, OpaqueError>,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
@@ -67,11 +66,8 @@ mod imp {
|
||||
pub fn new(cfg: &MitmConfig) -> Result<Self> {
|
||||
let (ca_cert_pem, ca_key_pem) = load_or_create_ca(cfg)?;
|
||||
let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
|
||||
let ca_params = CertificateParams::from_ca_cert_pem(&ca_cert_pem)
|
||||
let issuer: Issuer<'static, KeyPair> = Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)
|
||||
.context("failed to parse CA cert")?;
|
||||
let ca_cert = ca_params
|
||||
.self_signed(&ca_key)
|
||||
.context("failed to reconstruct CA cert")?;
|
||||
|
||||
let tls_config = rama::tls::rustls::client::TlsConnectorData::new_http_auto()
|
||||
.context("create upstream TLS config")?;
|
||||
@@ -84,8 +80,7 @@ mod imp {
|
||||
.boxed();
|
||||
|
||||
Ok(Self {
|
||||
ca_key,
|
||||
ca_cert,
|
||||
issuer,
|
||||
upstream,
|
||||
inspect: cfg.inspect,
|
||||
max_body_bytes: cfg.max_body_bytes,
|
||||
@@ -93,8 +88,7 @@ mod imp {
|
||||
}
|
||||
|
||||
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
|
||||
let (cert_pem, key_pem) =
|
||||
issue_host_certificate_pem(host, &self.ca_cert, &self.ca_key)?;
|
||||
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
|
||||
let cert_chain = pemfile::certs(&mut BufReader::new(cert_pem.as_bytes()))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.context("failed to parse host cert PEM")?;
|
||||
@@ -160,7 +154,7 @@ mod imp {
|
||||
let response = match forward_request(ctx, req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "MITM upstream request failed");
|
||||
warn!("MITM upstream request failed: {err}");
|
||||
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
|
||||
}
|
||||
};
|
||||
@@ -201,11 +195,7 @@ mod imp {
|
||||
if let Some(request_host) = extract_request_host(&req) {
|
||||
let normalized = normalize_host(&request_host);
|
||||
if !normalized.is_empty() && normalized != target_host {
|
||||
warn!(
|
||||
target = %target_host,
|
||||
request_host = %normalized,
|
||||
"MITM host mismatch"
|
||||
);
|
||||
warn!("MITM host mismatch (target={target_host}, request_host={normalized})");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
|
||||
}
|
||||
}
|
||||
@@ -223,12 +213,7 @@ mod imp {
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
host = %target_host,
|
||||
method = %method,
|
||||
path = %path,
|
||||
mode = ?mode,
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"MITM blocked by method policy"
|
||||
"MITM blocked by method policy (host={target_host}, method={method}, path={path}, mode={mode:?}, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(blocked_text("method_not_allowed"));
|
||||
}
|
||||
@@ -355,27 +340,23 @@ mod imp {
|
||||
|
||||
impl BodyLoggable for RequestLogContext {
|
||||
fn log(self, len: usize, truncated: bool) {
|
||||
let host = self.host;
|
||||
let method = self.method;
|
||||
let path = self.path;
|
||||
info!(
|
||||
host = %self.host,
|
||||
method = %self.method,
|
||||
path = %self.path,
|
||||
body_len = len,
|
||||
truncated = truncated,
|
||||
"MITM inspected request body"
|
||||
"MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl BodyLoggable for ResponseLogContext {
|
||||
fn log(self, len: usize, truncated: bool) {
|
||||
let host = self.host;
|
||||
let method = self.method;
|
||||
let path = self.path;
|
||||
let status = self.status;
|
||||
info!(
|
||||
host = %self.host,
|
||||
method = %self.method,
|
||||
path = %self.path,
|
||||
status = %self.status,
|
||||
body_len = len,
|
||||
truncated = truncated,
|
||||
"MITM inspected response body"
|
||||
"MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -384,7 +365,7 @@ mod imp {
|
||||
req.headers()
|
||||
.get(HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.to_string())
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
|
||||
}
|
||||
|
||||
@@ -410,15 +391,14 @@ mod imp {
|
||||
|
||||
fn path_and_query(uri: &Uri) -> String {
|
||||
uri.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.map(PathAndQuery::as_str)
|
||||
.unwrap_or("/")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn issue_host_certificate_pem(
|
||||
host: &str,
|
||||
ca_cert: &Certificate,
|
||||
ca_key: &KeyPair,
|
||||
issuer: &Issuer<'_, KeyPair>,
|
||||
) -> Result<(String, String)> {
|
||||
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
let mut params = CertificateParams::new(Vec::new())
|
||||
@@ -436,10 +416,10 @@ mod imp {
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
|
||||
let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
|
||||
let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
|
||||
.map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
|
||||
let cert = params
|
||||
.signed_by(&key_pair, ca_cert, ca_key)
|
||||
.signed_by(&key_pair, issuer)
|
||||
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
|
||||
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
@@ -472,11 +452,9 @@ mod imp {
|
||||
let (cert_pem, key_pem) = generate_ca()?;
|
||||
write_private_file(cert_path, cert_pem.as_bytes(), 0o644)?;
|
||||
write_private_file(key_path, key_pem.as_bytes(), 0o600)?;
|
||||
info!(
|
||||
cert_path = %cert_path.display(),
|
||||
key_path = %key_path.display(),
|
||||
"generated MITM CA"
|
||||
);
|
||||
let cert_path = cert_path.display();
|
||||
let key_path = key_path.display();
|
||||
info!("generated MITM CA (cert_path={cert_path}, key_path={key_path})");
|
||||
Ok((cert_pem, key_pem))
|
||||
}
|
||||
|
||||
@@ -492,7 +470,7 @@ mod imp {
|
||||
dn.push(DnType::CommonName, "network_proxy MITM CA");
|
||||
params.distinguished_name = dn;
|
||||
|
||||
let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
|
||||
let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
|
||||
.map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
|
||||
let cert = params
|
||||
.self_signed(&key_pair)
|
||||
@@ -557,45 +535,4 @@ mod imp {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "mitm"))]
|
||||
mod imp {
|
||||
use crate::config::MitmConfig;
|
||||
use crate::config::NetworkMode;
|
||||
use crate::state::AppState;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama::Context as RamaContext;
|
||||
use rama::http::layer::upgrade::Upgraded;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MitmState;
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl MitmState {
|
||||
pub fn new(_cfg: &MitmConfig) -> Result<Self> {
|
||||
Err(anyhow!("MITM feature disabled at build time"))
|
||||
}
|
||||
|
||||
pub fn inspect_enabled(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn max_body_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mitm_tunnel(
|
||||
_ctx: RamaContext<Arc<AppState>>,
|
||||
_upgraded: Upgraded,
|
||||
_host: &str,
|
||||
_port: u16,
|
||||
_mode: NetworkMode,
|
||||
_state: Arc<MitmState>,
|
||||
) -> Result<()> {
|
||||
Err(anyhow!("MITM feature disabled at build time"))
|
||||
}
|
||||
}
|
||||
|
||||
pub use imp::*;
|
||||
|
||||
@@ -21,10 +21,10 @@ pub fn is_loopback_host(host: &str) -> bool {
|
||||
|
||||
pub fn normalize_host(host: &str) -> String {
|
||||
let host = host.trim();
|
||||
if host.starts_with('[') {
|
||||
if let Some(end) = host.find(']') {
|
||||
return host[1..end].to_ascii_lowercase();
|
||||
}
|
||||
if host.starts_with('[')
|
||||
&& let Some(end) = host.find(']')
|
||||
{
|
||||
return host[1..end].to_ascii_lowercase();
|
||||
}
|
||||
host.split(':').next().unwrap_or("").to_ascii_lowercase()
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use hyper::Body;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use rama::http::Body;
|
||||
use rama::http::Response;
|
||||
use rama::http::StatusCode;
|
||||
use serde::Serialize;
|
||||
|
||||
pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
|
||||
pub fn text_response(status: StatusCode, body: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "text/plain")
|
||||
@@ -11,7 +11,7 @@ pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
|
||||
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
|
||||
}
|
||||
|
||||
pub fn json_response<T: Serialize>(value: &T) -> Response<Body> {
|
||||
pub fn json_response<T: Serialize>(value: &T) -> Response {
|
||||
let body = match serde_json::to_string(value) {
|
||||
Ok(body) => body,
|
||||
Err(_) => "{}".to_string(),
|
||||
|
||||
@@ -26,24 +26,21 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
.await
|
||||
.map_err(|err| anyhow!("bind SOCKS5 proxy: {err}"))?;
|
||||
|
||||
info!(addr = %addr, "SOCKS5 proxy listening");
|
||||
info!("SOCKS5 proxy listening on {addr}");
|
||||
|
||||
match state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
info!(
|
||||
mode = "limited",
|
||||
"SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5"
|
||||
);
|
||||
info!("SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5");
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to read network mode");
|
||||
warn!("failed to read network mode: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
let tcp_connector = TcpConnector::default();
|
||||
let policy_tcp_connector =
|
||||
service_fn(move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
|
||||
let policy_tcp_connector = service_fn(
|
||||
move |ctx: RamaContext<Arc<AppState>>, req: TcpRequest| {
|
||||
let tcp_connector = tcp_connector.clone();
|
||||
async move {
|
||||
let app_state = ctx.state().clone();
|
||||
@@ -66,12 +63,9 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
mode = "limited",
|
||||
allowed_methods = "GET, HEAD, OPTIONS",
|
||||
"SOCKS blocked by method policy"
|
||||
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
@@ -79,8 +73,8 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate method policy");
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,33 +90,26 @@ pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
reason = %reason,
|
||||
"SOCKS blocked"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok((false, _)) => {
|
||||
info!(
|
||||
client = %client.as_deref().unwrap_or_default(),
|
||||
host = %host,
|
||||
port = port,
|
||||
"SOCKS allowed"
|
||||
);
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("SOCKS allowed (client={client}, host={host}, port={port})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!(error = %err, "failed to evaluate host");
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "proxy error").into());
|
||||
error!("failed to evaluate host: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
tcp_connector.serve(ctx, req).await
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
|
||||
let socks_acceptor = Socks5Acceptor::new().with_connector(socks_connector);
|
||||
|
||||
@@ -6,10 +6,15 @@ use crate::policy::is_loopback_host;
|
||||
use crate::policy::method_allowed;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_app_server_protocol::ConfigLayerSource;
|
||||
use codex_core::config::CONFIG_TOML_FILE;
|
||||
use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config::Constrained;
|
||||
use codex_core::config::ConstraintError;
|
||||
use globset::GlobBuilder;
|
||||
use globset::GlobSet;
|
||||
use globset::GlobSetBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
@@ -17,7 +22,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use time::OffsetDateTime;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
@@ -58,7 +63,7 @@ impl BlockedRequest {
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConfigState {
|
||||
cfg: Config,
|
||||
config: Config,
|
||||
mtime: Option<SystemTime>,
|
||||
allow_set: GlobSet,
|
||||
deny_set: GlobSet,
|
||||
@@ -73,8 +78,8 @@ pub struct AppState {
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new(cfg_path: PathBuf) -> Result<Self> {
|
||||
let cfg_state = build_config_state(cfg_path)?;
|
||||
pub async fn new() -> Result<Self> {
|
||||
let cfg_state = build_config_state().await?;
|
||||
Ok(Self {
|
||||
state: Arc::new(RwLock::new(cfg_state)),
|
||||
})
|
||||
@@ -83,33 +88,34 @@ impl AppState {
|
||||
pub async fn current_cfg(&self) -> Result<Config> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.cfg.clone())
|
||||
Ok(guard.config.clone())
|
||||
}
|
||||
|
||||
pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok((
|
||||
guard.cfg.network_proxy.policy.allowed_domains.clone(),
|
||||
guard.cfg.network_proxy.policy.denied_domains.clone(),
|
||||
guard.config.network_proxy.policy.allowed_domains.clone(),
|
||||
guard.config.network_proxy.policy.denied_domains.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn force_reload(&self) -> Result<()> {
|
||||
let mut guard = self.state.write().await;
|
||||
let previous_cfg = guard.cfg.clone();
|
||||
let previous_cfg = guard.config.clone();
|
||||
let blocked = guard.blocked.clone();
|
||||
let cfg_path = guard.cfg_path.clone();
|
||||
match build_config_state(cfg_path.clone()) {
|
||||
match build_config_state().await {
|
||||
Ok(mut new_state) => {
|
||||
log_policy_changes(&previous_cfg, &new_state.cfg);
|
||||
log_policy_changes(&previous_cfg, &new_state.config);
|
||||
new_state.blocked = blocked;
|
||||
*guard = new_state;
|
||||
info!(path = %cfg_path.display(), "reloaded config");
|
||||
let path = guard.cfg_path.display();
|
||||
info!("reloaded config from {path}");
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, path = %cfg_path.display(), "failed to reload config; keeping previous config");
|
||||
let path = guard.cfg_path.display();
|
||||
warn!("failed to reload config from {path}: {err}; keeping previous config");
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
@@ -123,12 +129,12 @@ impl AppState {
|
||||
}
|
||||
let is_loopback = is_loopback_host(host);
|
||||
if is_loopback
|
||||
&& !guard.cfg.network_proxy.policy.allow_local_binding
|
||||
&& !guard.config.network_proxy.policy.allow_local_binding
|
||||
&& !guard.allow_set.is_match(host)
|
||||
{
|
||||
return Ok((true, "not_allowed_local".to_string()));
|
||||
}
|
||||
if guard.cfg.network_proxy.policy.allowed_domains.is_empty()
|
||||
if guard.config.network_proxy.policy.allowed_domains.is_empty()
|
||||
|| !guard.allow_set.is_match(host)
|
||||
{
|
||||
return Ok((true, "not_allowed".to_string()));
|
||||
@@ -157,7 +163,7 @@ impl AppState {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard
|
||||
.cfg
|
||||
.config
|
||||
.network_proxy
|
||||
.policy
|
||||
.allow_unix_sockets
|
||||
@@ -168,20 +174,20 @@ impl AppState {
|
||||
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(method_allowed(guard.cfg.network_proxy.mode, method))
|
||||
Ok(method_allowed(guard.config.network_proxy.mode, method))
|
||||
}
|
||||
|
||||
pub async fn network_mode(&self) -> Result<NetworkMode> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.cfg.network_proxy.mode)
|
||||
Ok(guard.config.network_proxy.mode)
|
||||
}
|
||||
|
||||
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
|
||||
self.reload_if_needed().await?;
|
||||
let mut guard = self.state.write().await;
|
||||
guard.cfg.network_proxy.mode = mode;
|
||||
info!(mode = ?mode, "updated network mode");
|
||||
guard.config.network_proxy.mode = mode;
|
||||
info!("updated network mode to {mode:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -195,7 +201,9 @@ impl AppState {
|
||||
let needs_reload = {
|
||||
let guard = self.state.read().await;
|
||||
if !guard.cfg_path.exists() {
|
||||
true
|
||||
// If the config file is missing, only reload when it *used to* exist (mtime set).
|
||||
// This avoids forcing a reload on every request when running with the default config.
|
||||
guard.mtime.is_some()
|
||||
} else {
|
||||
let metadata = std::fs::metadata(&guard.cfg_path).ok();
|
||||
match (metadata.and_then(|m| m.modified().ok()), guard.mtime) {
|
||||
@@ -214,28 +222,32 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_config_state(cfg_path: PathBuf) -> Result<ConfigState> {
|
||||
let mut cfg = if cfg_path.exists() {
|
||||
load_config_from_path(&cfg_path).with_context(|| {
|
||||
format!(
|
||||
"failed to load config from {}",
|
||||
cfg_path.as_path().display()
|
||||
)
|
||||
})?
|
||||
} else {
|
||||
Config::default()
|
||||
};
|
||||
resolve_mitm_paths(&mut cfg, &cfg_path);
|
||||
async fn build_config_state() -> Result<ConfigState> {
|
||||
let codex_cfg = ConfigBuilder::default()
|
||||
.build()
|
||||
.await
|
||||
.context("failed to load Codex config")?;
|
||||
|
||||
let cfg_path = codex_cfg.codex_home.join(CONFIG_TOML_FILE);
|
||||
|
||||
let merged_toml = codex_cfg.config_layer_stack.effective_config();
|
||||
let mut config: Config = merged_toml
|
||||
.try_into()
|
||||
.context("failed to deserialize network proxy config")?;
|
||||
|
||||
enforce_trusted_constraints(&codex_cfg.config_layer_stack, &config)?;
|
||||
|
||||
resolve_mitm_paths(&mut config, &cfg_path);
|
||||
let mtime = cfg_path.metadata().and_then(|m| m.modified()).ok();
|
||||
let deny_set = compile_globset(&cfg.network_proxy.policy.denied_domains)?;
|
||||
let allow_set = compile_globset(&cfg.network_proxy.policy.allowed_domains)?;
|
||||
let mitm = if cfg.network_proxy.mitm.enabled {
|
||||
build_mitm_state(&cfg.network_proxy.mitm)?
|
||||
let deny_set = compile_globset(&config.network_proxy.policy.denied_domains)?;
|
||||
let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains)?;
|
||||
let mitm = if config.network_proxy.mitm.enabled {
|
||||
build_mitm_state(&config.network_proxy.mitm)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(ConfigState {
|
||||
cfg,
|
||||
config,
|
||||
mtime,
|
||||
allow_set,
|
||||
deny_set,
|
||||
@@ -245,25 +257,248 @@ fn build_config_state(cfg_path: PathBuf) -> Result<ConfigState> {
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_mitm_paths(cfg: &mut Config, cfg_path: &Path) {
|
||||
fn resolve_mitm_paths(config: &mut Config, cfg_path: &Path) {
|
||||
let base = cfg_path.parent().unwrap_or_else(|| Path::new("."));
|
||||
if cfg.network_proxy.mitm.ca_cert_path.is_relative() {
|
||||
cfg.network_proxy.mitm.ca_cert_path = base.join(&cfg.network_proxy.mitm.ca_cert_path);
|
||||
if config.network_proxy.mitm.ca_cert_path.is_relative() {
|
||||
config.network_proxy.mitm.ca_cert_path = base.join(&config.network_proxy.mitm.ca_cert_path);
|
||||
}
|
||||
if cfg.network_proxy.mitm.ca_key_path.is_relative() {
|
||||
cfg.network_proxy.mitm.ca_key_path = base.join(&cfg.network_proxy.mitm.ca_key_path);
|
||||
if config.network_proxy.mitm.ca_key_path.is_relative() {
|
||||
config.network_proxy.mitm.ca_key_path = base.join(&config.network_proxy.mitm.ca_key_path);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_mitm_state(_cfg: &MitmConfig) -> Result<Option<Arc<MitmState>>> {
|
||||
#[cfg(feature = "mitm")]
|
||||
fn build_mitm_state(config: &MitmConfig) -> Result<Option<Arc<MitmState>>> {
|
||||
Ok(Some(Arc::new(MitmState::new(config)?)))
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct PartialConfig {
|
||||
#[serde(default)]
|
||||
network_proxy: PartialNetworkProxyConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct PartialNetworkProxyConfig {
|
||||
enabled: Option<bool>,
|
||||
mode: Option<NetworkMode>,
|
||||
#[serde(default)]
|
||||
policy: PartialNetworkPolicy,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct PartialNetworkPolicy {
|
||||
#[serde(default)]
|
||||
allowed_domains: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
denied_domains: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
allow_unix_sockets: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
allow_local_binding: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct NetworkProxyConstraints {
|
||||
enabled: Option<bool>,
|
||||
mode: Option<NetworkMode>,
|
||||
allowed_domains: Option<Vec<String>>,
|
||||
denied_domains: Option<Vec<String>>,
|
||||
allow_unix_sockets: Option<Vec<String>>,
|
||||
allow_local_binding: Option<bool>,
|
||||
}
|
||||
|
||||
fn enforce_trusted_constraints(
|
||||
layers: &codex_core::config_loader::ConfigLayerStack,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let constraints = network_proxy_constraints_from_trusted_layers(layers)?;
|
||||
validate_policy_against_constraints(config, &constraints)
|
||||
.context("network proxy constraints")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn network_proxy_constraints_from_trusted_layers(
|
||||
layers: &codex_core::config_loader::ConfigLayerStack,
|
||||
) -> Result<NetworkProxyConstraints> {
|
||||
let mut constraints = NetworkProxyConstraints::default();
|
||||
for layer in layers
|
||||
.get_layers(codex_core::config_loader::ConfigLayerStackOrdering::LowestPrecedenceFirst)
|
||||
{
|
||||
return Ok(Some(Arc::new(MitmState::new(_cfg)?)));
|
||||
if is_user_controlled_layer(&layer.name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let partial: PartialConfig = layer
|
||||
.config
|
||||
.clone()
|
||||
.try_into()
|
||||
.context("failed to deserialize trusted config layer")?;
|
||||
|
||||
if let Some(enabled) = partial.network_proxy.enabled {
|
||||
constraints.enabled = Some(enabled);
|
||||
}
|
||||
if let Some(mode) = partial.network_proxy.mode {
|
||||
constraints.mode = Some(mode);
|
||||
}
|
||||
|
||||
if let Some(allowed_domains) = partial.network_proxy.policy.allowed_domains {
|
||||
constraints.allowed_domains = Some(allowed_domains);
|
||||
}
|
||||
if let Some(denied_domains) = partial.network_proxy.policy.denied_domains {
|
||||
constraints.denied_domains = Some(denied_domains);
|
||||
}
|
||||
if let Some(allow_unix_sockets) = partial.network_proxy.policy.allow_unix_sockets {
|
||||
constraints.allow_unix_sockets = Some(allow_unix_sockets);
|
||||
}
|
||||
if let Some(allow_local_binding) = partial.network_proxy.policy.allow_local_binding {
|
||||
constraints.allow_local_binding = Some(allow_local_binding);
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "mitm"))]
|
||||
{
|
||||
warn!("MITM enabled in config but binary built without mitm feature");
|
||||
Ok(None)
|
||||
Ok(constraints)
|
||||
}
|
||||
|
||||
fn is_user_controlled_layer(layer: &ConfigLayerSource) -> bool {
|
||||
matches!(
|
||||
layer,
|
||||
ConfigLayerSource::User { .. }
|
||||
| ConfigLayerSource::Project { .. }
|
||||
| ConfigLayerSource::SessionFlags
|
||||
)
|
||||
}
|
||||
|
||||
fn validate_policy_against_constraints(
|
||||
config: &Config,
|
||||
constraints: &NetworkProxyConstraints,
|
||||
) -> std::result::Result<(), ConstraintError> {
|
||||
let enabled = config.network_proxy.enabled;
|
||||
if let Some(max_enabled) = constraints.enabled {
|
||||
let _ = Constrained::new(enabled, move |candidate| {
|
||||
if *candidate && !max_enabled {
|
||||
Err(ConstraintError::invalid_value(
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
if let Some(max_mode) = constraints.mode {
|
||||
let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
|
||||
if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
|
||||
Err(ConstraintError::invalid_value(
|
||||
format!("{candidate:?}"),
|
||||
format!("{max_mode:?} or more restrictive"),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
if let Some(allow_local_binding) = constraints.allow_local_binding {
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.allow_local_binding,
|
||||
move |candidate| {
|
||||
if *candidate && !allow_local_binding {
|
||||
Err(ConstraintError::invalid_value(
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(allowed_domains) = &constraints.allowed_domains {
|
||||
let allowed_set: HashSet<String> = allowed_domains
|
||||
.iter()
|
||||
.map(|s| s.to_ascii_lowercase())
|
||||
.collect();
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.allowed_domains.clone(),
|
||||
move |candidate| {
|
||||
let mut invalid = Vec::new();
|
||||
for entry in candidate {
|
||||
if !allowed_set.contains(&entry.to_ascii_lowercase()) {
|
||||
invalid.push(entry.clone());
|
||||
}
|
||||
}
|
||||
if invalid.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConstraintError::invalid_value(
|
||||
format!("{invalid:?}"),
|
||||
"subset of managed allowed_domains",
|
||||
))
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(denied_domains) = &constraints.denied_domains {
|
||||
let required_set: HashSet<String> = denied_domains
|
||||
.iter()
|
||||
.map(|s| s.to_ascii_lowercase())
|
||||
.collect();
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.denied_domains.clone(),
|
||||
move |candidate| {
|
||||
let candidate_set: HashSet<String> =
|
||||
candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
|
||||
let missing: Vec<String> = required_set
|
||||
.iter()
|
||||
.filter(|entry| !candidate_set.contains(*entry))
|
||||
.cloned()
|
||||
.collect();
|
||||
if missing.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConstraintError::invalid_value(
|
||||
"missing managed denied_domains entries",
|
||||
format!("{missing:?}"),
|
||||
))
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
|
||||
let allowed_set: HashSet<String> = allow_unix_sockets
|
||||
.iter()
|
||||
.map(|s| s.to_ascii_lowercase())
|
||||
.collect();
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.allow_unix_sockets.clone(),
|
||||
move |candidate| {
|
||||
let mut invalid = Vec::new();
|
||||
for entry in candidate {
|
||||
if !allowed_set.contains(&entry.to_ascii_lowercase()) {
|
||||
invalid.push(entry.clone());
|
||||
}
|
||||
}
|
||||
if invalid.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConstraintError::invalid_value(
|
||||
format!("{invalid:?}"),
|
||||
"subset of managed allow_unix_sockets",
|
||||
))
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn network_mode_rank(mode: NetworkMode) -> u8 {
|
||||
match mode {
|
||||
NetworkMode::Limited => 0,
|
||||
NetworkMode::Full => 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,7 +552,7 @@ fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]
|
||||
for entry in next {
|
||||
let key = entry.to_ascii_lowercase();
|
||||
if seen_next.insert(key.clone()) && !previous_set.contains(&key) {
|
||||
info!(list = list_name, entry = %entry, "config entry added");
|
||||
info!("config entry added to {list_name}: {entry}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,20 +560,11 @@ fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]
|
||||
for entry in previous {
|
||||
let key = entry.to_ascii_lowercase();
|
||||
if seen_previous.insert(key.clone()) && !next_set.contains(&key) {
|
||||
info!(list = list_name, entry = %entry, "config entry removed");
|
||||
info!("config entry removed from {list_name}: {entry}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unix_timestamp() -> i64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|duration| duration.as_secs() as i64)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn load_config_from_path(path: &Path) -> Result<Config> {
|
||||
let raw = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("unable to read config file {}", path.display()))?;
|
||||
toml::from_str(&raw).map_err(|err| anyhow!("unable to parse config: {err}"))
|
||||
OffsetDateTime::now_utc().unix_timestamp()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user