Mirror of https://github.com/openai/codex.git (synced 2026-02-01 14:44:17 +00:00)
feat: introducing a network sandbox proxy (#8442)
This adds a new crate, `codex-network-proxy`: a local network proxy service used by Codex to enforce fine-grained network policy (domain allow/deny) and to surface blocked network events for interactive approvals.

- New crate: `codex-rs/network-proxy/` (`codex-network-proxy` binary + library)
- Core capabilities:
  - HTTP proxy support (including CONNECT tunneling)
  - SOCKS5 proxy support (in a later PR)
  - policy evaluation (allowed/denied domain lists; denylist wins; wildcard support)
  - small admin API for polling/reload/mode changes
  - optional MITM support for HTTPS CONNECT to enforce "limited mode" method restrictions (later PR)

Integration with Codex will follow in subsequent PRs.

## Testing

- `cd codex-rs && cargo build -p codex-network-proxy`
- `cd codex-rs && cargo run -p codex-network-proxy -- proxy`
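For a quick manual smoke test (a sketch, not part of this PR: it assumes `network_proxy.enabled = true`, the default listener addresses from the crate's README, and an `allowed_domains` entry that covers the placeholder host below):

```bash
# Terminal 1: start the proxy (it reads the merged Codex config.toml).
cd codex-rs && cargo run -p codex-network-proxy -- proxy

# Terminal 2: send a plain-HTTP GET through the proxy (allowed in both modes),
# then poll the admin API for any blocked events.
curl -sS -o /dev/null -w '%{http_code}\n' -x http://127.0.0.1:3128 http://api.openai.com/
curl -sS http://127.0.0.1:8080/blocked
```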
.github/scripts/install-musl-build-tools.sh (vendored, new file, 61 lines)
@@ -0,0 +1,61 @@
#!/usr/bin/env bash
set -euo pipefail

: "${TARGET:?TARGET environment variable is required}"
: "${GITHUB_ENV:?GITHUB_ENV environment variable is required}"

apt_update_args=()
if [[ -n "${APT_UPDATE_ARGS:-}" ]]; then
  # shellcheck disable=SC2206
  apt_update_args=(${APT_UPDATE_ARGS})
fi

apt_install_args=()
if [[ -n "${APT_INSTALL_ARGS:-}" ]]; then
  # shellcheck disable=SC2206
  apt_install_args=(${APT_INSTALL_ARGS})
fi

sudo apt-get update "${apt_update_args[@]}"
sudo apt-get install -y "${apt_install_args[@]}" musl-tools pkg-config g++ clang libc++-dev libc++abi-dev lld

case "${TARGET}" in
  x86_64-unknown-linux-musl)
    arch="x86_64"
    ;;
  aarch64-unknown-linux-musl)
    arch="aarch64"
    ;;
  *)
    echo "Unexpected musl target: ${TARGET}" >&2
    exit 1
    ;;
esac

if command -v clang++ >/dev/null; then
  cxx="$(command -v clang++)"
  echo "CXXFLAGS=--target=${TARGET} -stdlib=libc++ -pthread" >> "$GITHUB_ENV"
  echo "CFLAGS=--target=${TARGET} -pthread" >> "$GITHUB_ENV"
  if command -v clang >/dev/null; then
    cc="$(command -v clang)"
    echo "CC=${cc}" >> "$GITHUB_ENV"
    echo "TARGET_CC=${cc}" >> "$GITHUB_ENV"
    target_cc_var="CC_${TARGET}"
    target_cc_var="${target_cc_var//-/_}"
    echo "${target_cc_var}=${cc}" >> "$GITHUB_ENV"
  fi
elif command -v "${arch}-linux-musl-g++" >/dev/null; then
  cxx="$(command -v "${arch}-linux-musl-g++")"
elif command -v musl-g++ >/dev/null; then
  cxx="$(command -v musl-g++)"
elif command -v musl-gcc >/dev/null; then
  cxx="$(command -v musl-gcc)"
  echo "CFLAGS=-pthread" >> "$GITHUB_ENV"
else
  echo "musl g++ not found after install; arch=${arch}" >&2
  exit 1
fi

echo "CXX=${cxx}" >> "$GITHUB_ENV"
echo "CMAKE_CXX_COMPILER=${cxx}" >> "$GITHUB_ENV"
echo "CMAKE_ARGS=-DCMAKE_HAVE_THREADS_LIBRARY=1 -DCMAKE_USE_PTHREADS_INIT=1 -DCMAKE_THREAD_LIBS_INIT=-pthread -DTHREADS_PREFER_PTHREAD_FLAG=ON" >> "$GITHUB_ENV"
.github/workflows/rust-ci.yml (vendored, 8 lines changed)
@@ -265,11 +265,11 @@ jobs:
         name: Install musl build tools
         env:
           DEBIAN_FRONTEND: noninteractive
+          TARGET: ${{ matrix.target }}
+          APT_UPDATE_ARGS: -o Acquire::Retries=3
+          APT_INSTALL_ARGS: --no-install-recommends
         shell: bash
-        run: |
-          set -euo pipefail
-          sudo apt-get -y update -o Acquire::Retries=3
-          sudo apt-get -y install --no-install-recommends musl-tools pkg-config
+        run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"

       - name: Install cargo-chef
         if: ${{ matrix.profile == 'release' }}
.github/workflows/rust-release.yml (vendored, 6 lines changed)
@@ -106,9 +106,9 @@ jobs:
       - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
         name: Install musl build tools
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y musl-tools pkg-config
+        env:
+          TARGET: ${{ matrix.target }}
+        run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"

       - name: Cargo build
         shell: bash
.github/workflows/shell-tool-mcp.yml (vendored, 6 lines changed)
@@ -99,9 +99,9 @@ jobs:
       - if: ${{ matrix.install_musl }}
         name: Install musl build dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y musl-tools pkg-config
+        env:
+          TARGET: ${{ matrix.target }}
+        run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"

       - name: Build exec server binaries
         run: cargo build --release --target ${{ matrix.target }} --bin codex-exec-mcp-server --bin codex-execve-wrapper
codex-rs/Cargo.lock (generated, 846 lines changed; file diff suppressed because it is too large)
codex-rs/Cargo.toml

@@ -27,6 +27,7 @@ members = [
     "login",
     "mcp-server",
     "mcp-types",
+    "network-proxy",
     "ollama",
     "process-hardening",
     "protocol",
@@ -136,6 +137,7 @@ env-flags = "0.1.1"
env_logger = "0.11.5"
eventsource-stream = "0.2.3"
futures = { version = "0.3", default-features = false }
globset = "0.4"
http = "1.3.1"
icu_decimal = "2.1"
icu_locale_core = "2.1"
codex-rs/deny.toml

@@ -73,6 +73,7 @@ ignore = [
     { id = "RUSTSEC-2024-0388", reason = "derivative is unmaintained; pulled in via starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
     { id = "RUSTSEC-2025-0057", reason = "fxhash is unmaintained; pulled in via starlark_map/starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
     { id = "RUSTSEC-2024-0436", reason = "paste is unmaintained; pulled in via ratatui/rmcp/starlark used by tui/execpolicy; no fixed release yet" },
+    { id = "RUSTSEC-2025-0134", reason = "rustls-pemfile is unmaintained; pulled in via rama-tls-rustls used by codex-network-proxy; no safe upgrade until rama removes the dependency" },
     # TODO(joshka, nornagon): remove this exception when once we update the ratatui fork to a version that uses lru 0.13+.
     { id = "RUSTSEC-2026-0002", reason = "lru 0.12.5 is pulled in via ratatui fork; cannot upgrade until the fork is updated" },
 ]
codex-rs/network-proxy/Cargo.toml (new file, 45 lines)
@@ -0,0 +1,45 @@
[package]
name = "codex-network-proxy"
edition = "2024"
version = { workspace = true }
license.workspace = true

[[bin]]
name = "codex-network-proxy"
path = "src/main.rs"

[lib]
name = "codex_network_proxy"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-app-server-protocol = { workspace = true }
codex-core = { workspace = true }
codex-utils-absolute-path = { workspace = true }
globset = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
time = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt"] }
url = { workspace = true }
rama-core = { version = "=0.3.0-alpha.4" }
rama-http = { version = "=0.3.0-alpha.4" }
rama-http-backend = { version = "=0.3.0-alpha.4", features = ["tls"] }
rama-net = { version = "=0.3.0-alpha.4", features = ["http", "tls"] }
rama-tcp = { version = "=0.3.0-alpha.4", features = ["http"] }
rama-tls-boring = { version = "=0.3.0-alpha.4", features = ["http"] }

[dev-dependencies]
pretty_assertions = { workspace = true }
tempfile = { workspace = true }

[target.'cfg(target_family = "unix")'.dependencies]
rama-unix = { version = "=0.3.0-alpha.4" }
codex-rs/network-proxy/README.md (new file, 169 lines)
@@ -0,0 +1,169 @@
# codex-network-proxy

`codex-network-proxy` is Codex's local network policy enforcement proxy. It runs:

- an HTTP proxy (default `127.0.0.1:3128`)
- an admin HTTP API (default `127.0.0.1:8080`)

It enforces an allow/deny policy and a "limited" mode intended for read-only network access.

## Quickstart

### 1) Configure

`codex-network-proxy` reads from Codex's merged `config.toml` (via `codex-core` config loading).

Example config:

```toml
[network_proxy]
enabled = true
proxy_url = "http://127.0.0.1:3128"
admin_url = "http://127.0.0.1:8080"
# When `enabled` is false, the proxy no-ops and does not bind listeners.
# When true, respect HTTP(S)_PROXY/ALL_PROXY for upstream requests (HTTP(S) proxies only),
# including CONNECT tunnels in full mode.
allow_upstream_proxy = false
# By default, non-loopback binds are clamped to loopback for safety.
# If you want to expose these listeners beyond localhost, you must opt in explicitly.
dangerously_allow_non_loopback_proxy = false
dangerously_allow_non_loopback_admin = false
mode = "limited" # or "full"

[network_proxy.policy]
# Hosts must match the allowlist (unless denied).
# If `allowed_domains` is empty, the proxy blocks requests until an allowlist is configured.
allowed_domains = ["*.openai.com"]
denied_domains = ["evil.example"]

# If false, local/private networking is rejected. Explicit allowlisting of local IP literals
# (or `localhost`) is required to permit them.
# Hostnames that resolve to local/private IPs are still blocked even if allowlisted.
allow_local_binding = false

# macOS-only: allows proxying to a unix socket when the request includes `x-unix-socket: /path`.
allow_unix_sockets = ["/tmp/example.sock"]
```

### 2) Run the proxy

```bash
cargo run -p codex-network-proxy --
```

### 3) Point a client at it

For HTTP(S) traffic:

```bash
export HTTP_PROXY="http://127.0.0.1:3128"
export HTTPS_PROXY="http://127.0.0.1:3128"
```

### 4) Understand blocks / debugging

When a request is blocked, the proxy responds with `403` and includes:

- `x-proxy-error`: one of:
  - `blocked-by-allowlist`
  - `blocked-by-denylist`
  - `blocked-by-method-policy`
  - `blocked-by-policy`

In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed for plain HTTP. HTTPS `CONNECT` is not inspected (there is no MITM yet), so in limited mode `CONNECT` is blocked outright rather than method-filtered inside the tunnel.
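To check a block from the client side, send a request through the proxy with `curl` and dump the response headers. This is a sketch: the hosts and paths are placeholders, and it assumes the example policy above with the proxy on its default port.

```bash
# An allowlist miss is expected to come back 403 with an x-proxy-error header.
curl -sS -o /dev/null -D - -x http://127.0.0.1:3128 http://not-on-allowlist.example/
#   HTTP/1.1 403 Forbidden
#   x-proxy-error: blocked-by-allowlist

# In limited mode, non-read-only methods are rejected even for allowed hosts.
curl -sS -o /dev/null -D - -x http://127.0.0.1:3128 -X POST http://api.openai.com/example
#   x-proxy-error: blocked-by-method-policy
```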

## Library API

`codex-network-proxy` can be embedded as a library with a thin API:

```rust
use codex_network_proxy::{NetworkProxy, NetworkDecision, NetworkPolicyRequest};

let proxy = NetworkProxy::builder()
    .http_addr("127.0.0.1:8080".parse()?)
    .admin_addr("127.0.0.1:9000".parse()?)
    .policy_decider(|request: NetworkPolicyRequest| async move {
        // Example: auto-allow when exec policy already approved a command prefix.
        if let Some(command) = request.command.as_deref() {
            if command.starts_with("curl ") {
                return NetworkDecision::Allow;
            }
        }
        NetworkDecision::Deny {
            reason: "policy_denied".to_string(),
        }
    })
    .build()
    .await?;

let handle = proxy.run().await?;
handle.shutdown().await?;
```

When unix socket proxying is enabled, HTTP/admin bind overrides are still clamped to loopback to avoid turning the proxy into a remote bridge to local daemons.

### Policy hook (exec-policy mapping)

The proxy exposes a policy hook (`NetworkPolicyDecider`) that can override allowlist-only blocks. It receives `command` and `exec_policy_hint` fields when supplied by the embedding app. This lets core map exec approvals to network access, e.g. if a user already approved `curl *` for a session, the decider can auto-allow network requests originating from that command.

**Important:** Explicit deny rules still win. The decider only gets a chance to override `not_allowed` (allowlist misses), not `denied` or `not_allowed_local`.

## Admin API

The admin API is a small HTTP server intended for debugging and runtime adjustments.

Endpoints:

```bash
curl -sS http://127.0.0.1:8080/health
curl -sS http://127.0.0.1:8080/config
curl -sS http://127.0.0.1:8080/patterns
curl -sS http://127.0.0.1:8080/blocked

# Switch modes without restarting:
curl -sS -X POST http://127.0.0.1:8080/mode -d '{"mode":"full"}'

# Force a config reload:
curl -sS -X POST http://127.0.0.1:8080/reload
```
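The response shapes follow the serde structs in `src/admin.rs`; the example values below are illustrative and depend on the loaded policy and runtime state.

```bash
curl -sS http://127.0.0.1:8080/health
# ok

curl -sS http://127.0.0.1:8080/patterns
# {"allowed":["*.openai.com"],"denied":["evil.example"]}

curl -sS http://127.0.0.1:8080/blocked
# {"blocked":[...]}   (drained queue of blocked-request records)

curl -sS -X POST http://127.0.0.1:8080/mode -d '{"mode":"full"}'
# {"status":"ok","mode":"full"}

curl -sS -X POST http://127.0.0.1:8080/reload
# {"status":"reloaded"}
```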

## Platform notes

- Unix socket proxying via the `x-unix-socket` header is **macOS-only**; other platforms will reject unix socket requests.
- HTTPS tunneling uses BoringSSL via Rama's `rama-tls-boring`; building the proxy requires a native toolchain and CMake on macOS/Linux/Windows.

## Security notes (important)

This section documents the protections implemented by `codex-network-proxy`, and the boundaries of what it can reasonably guarantee.

- Allowlist-first policy: if `allowed_domains` is empty, requests are blocked until an allowlist is configured.
- Deny wins: entries in `denied_domains` always override the allowlist.
- Local/private network protection: when `allow_local_binding = false`, the proxy blocks loopback and common private/link-local ranges. Explicit allowlisting of local IP literals (or `localhost`) is required to permit them; hostnames that resolve to local/private IPs are still blocked even if allowlisted (best-effort DNS lookup).
- Limited mode enforcement:
  - only `GET`, `HEAD`, and `OPTIONS` are allowed for plain HTTP
  - HTTPS `CONNECT` cannot be method-inspected without MITM, so limited mode blocks `CONNECT` (MITM-based enforcement is planned in a later PR)
- Listener safety defaults:
  - the admin API is unauthenticated; non-loopback binds are clamped unless explicitly enabled via `dangerously_allow_non_loopback_admin`
  - the HTTP proxy listener similarly clamps non-loopback binds unless explicitly enabled via `dangerously_allow_non_loopback_proxy`
  - when unix socket proxying is enabled, both listeners are forced to loopback to avoid turning the proxy into a remote bridge into local daemons
- `enabled` is enforced at runtime; when false the proxy no-ops and does not bind listeners.

Limitations:

- DNS rebinding is hard to fully prevent without pinning the resolved IP(s) all the way down to the transport layer. If your threat model includes hostile DNS, enforce network egress at a lower layer too (e.g., firewall / VPC / corporate proxy policies).
codex-rs/network-proxy/src/admin.rs (new file, 160 lines)
@@ -0,0 +1,160 @@
use crate::config::NetworkMode;
use crate::responses::json_response;
use crate::responses::text_response;
use crate::state::NetworkProxyState;
use anyhow::Context;
use anyhow::Result;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http_backend::server::HttpServer;
use rama_tcp::server::TcpListener;
use serde::Deserialize;
use serde::Serialize;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;

pub async fn run_admin_api(state: Arc<NetworkProxyState>, addr: SocketAddr) -> Result<()> {
    // Debug-only admin API (health/config/patterns/blocked + mode/reload). Policy is config-driven
    // and constraint-enforced; this endpoint should not become a second policy/approval plane.
    let listener = TcpListener::build()
        .bind(addr)
        .await
        // See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
        .map_err(rama_core::error::OpaqueError::from)
        .map_err(anyhow::Error::from)
        .with_context(|| format!("bind admin API: {addr}"))?;

    let server_state = state.clone();
    let server = HttpServer::auto(Executor::new()).service(service_fn(move |req| {
        let state = server_state.clone();
        async move { handle_admin_request(state, req).await }
    }));
    info!("admin API listening on {addr}");
    listener.serve(server).await;
    Ok(())
}

async fn handle_admin_request(
    state: Arc<NetworkProxyState>,
    req: Request,
) -> Result<Response, Infallible> {
    const MODE_BODY_LIMIT: usize = 8 * 1024;

    let method = req.method().clone();
    let path = req.uri().path().to_string();
    let response = match (method.as_str(), path.as_str()) {
        ("GET", "/health") => Response::new(Body::from("ok")),
        ("GET", "/config") => match state.current_cfg().await {
            Ok(cfg) => json_response(&cfg),
            Err(err) => {
                error!("failed to load config: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
            }
        },
        ("GET", "/patterns") => match state.current_patterns().await {
            Ok((allow, deny)) => json_response(&PatternsResponse {
                allowed: allow,
                denied: deny,
            }),
            Err(err) => {
                error!("failed to load patterns: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
            }
        },
        ("GET", "/blocked") => match state.drain_blocked().await {
            Ok(blocked) => json_response(&BlockedResponse { blocked }),
            Err(err) => {
                error!("failed to read blocked queue: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
            }
        },
        ("POST", "/mode") => {
            let mut body = req.into_body();
            let mut buf: Vec<u8> = Vec::new();
            loop {
                let chunk = match body.chunk().await {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        error!("failed to read mode body: {err}");
                        return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
                    }
                };
                let Some(chunk) = chunk else {
                    break;
                };

                if buf.len().saturating_add(chunk.len()) > MODE_BODY_LIMIT {
                    return Ok(text_response(
                        StatusCode::PAYLOAD_TOO_LARGE,
                        "body too large",
                    ));
                }
                buf.extend_from_slice(&chunk);
            }

            if buf.is_empty() {
                return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
            }
            let update: ModeUpdate = match serde_json::from_slice(&buf) {
                Ok(update) => update,
                Err(err) => {
                    error!("failed to parse mode update: {err}");
                    return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
                }
            };
            match state.set_network_mode(update.mode).await {
                Ok(()) => json_response(&ModeUpdateResponse {
                    status: "ok",
                    mode: update.mode,
                }),
                Err(err) => {
                    error!("mode update failed: {err}");
                    text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
                }
            }
        }
        ("POST", "/reload") => match state.force_reload().await {
            Ok(()) => json_response(&ReloadResponse { status: "reloaded" }),
            Err(err) => {
                error!("reload failed: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
            }
        },
        _ => text_response(StatusCode::NOT_FOUND, "not found"),
    };
    Ok(response)
}

#[derive(Deserialize)]
struct ModeUpdate {
    mode: NetworkMode,
}

#[derive(Debug, Serialize)]
struct PatternsResponse {
    allowed: Vec<String>,
    denied: Vec<String>,
}

#[derive(Debug, Serialize)]
struct BlockedResponse<T> {
    blocked: T,
}

#[derive(Debug, Serialize)]
struct ModeUpdateResponse {
    status: &'static str,
    mode: NetworkMode,
}

#[derive(Debug, Serialize)]
struct ReloadResponse {
    status: &'static str,
}
codex-rs/network-proxy/src/config.rs (new file, 433 lines)
@@ -0,0 +1,433 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::bail;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::net::IpAddr;
|
||||
use std::net::SocketAddr;
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct NetworkProxyConfig {
|
||||
#[serde(default)]
|
||||
pub network_proxy: NetworkProxySettings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NetworkProxySettings {
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
#[serde(default = "default_proxy_url")]
|
||||
pub proxy_url: String,
|
||||
#[serde(default = "default_admin_url")]
|
||||
pub admin_url: String,
|
||||
#[serde(default)]
|
||||
pub allow_upstream_proxy: bool,
|
||||
#[serde(default)]
|
||||
pub dangerously_allow_non_loopback_proxy: bool,
|
||||
#[serde(default)]
|
||||
pub dangerously_allow_non_loopback_admin: bool,
|
||||
#[serde(default)]
|
||||
pub mode: NetworkMode,
|
||||
#[serde(default)]
|
||||
pub policy: NetworkPolicy,
|
||||
}
|
||||
|
||||
impl Default for NetworkProxySettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
proxy_url: default_proxy_url(),
|
||||
admin_url: default_admin_url(),
|
||||
allow_upstream_proxy: false,
|
||||
dangerously_allow_non_loopback_proxy: false,
|
||||
dangerously_allow_non_loopback_admin: false,
|
||||
mode: NetworkMode::default(),
|
||||
policy: NetworkPolicy::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct NetworkPolicy {
|
||||
#[serde(default)]
|
||||
pub allowed_domains: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub denied_domains: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub allow_unix_sockets: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub allow_local_binding: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NetworkMode {
|
||||
/// Limited (read-only) access: only GET/HEAD/OPTIONS are allowed for HTTP. HTTPS CONNECT is
|
||||
/// blocked unless MITM is enabled so the proxy can enforce method policy on inner requests.
|
||||
Limited,
|
||||
/// Full network access: all HTTP methods are allowed, and HTTPS CONNECTs are tunneled without
|
||||
/// MITM interception.
|
||||
#[default]
|
||||
Full,
|
||||
}
|
||||
|
||||
impl NetworkMode {
|
||||
pub fn allows_method(self, method: &str) -> bool {
|
||||
match self {
|
||||
Self::Full => true,
|
||||
Self::Limited => matches!(method, "GET" | "HEAD" | "OPTIONS"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_proxy_url() -> String {
|
||||
"http://127.0.0.1:3128".to_string()
|
||||
}
|
||||
|
||||
fn default_admin_url() -> String {
|
||||
"http://127.0.0.1:8080".to_string()
|
||||
}
|
||||
|
||||
/// Clamp non-loopback bind addresses to loopback unless explicitly allowed.
|
||||
fn clamp_non_loopback(addr: SocketAddr, allow_non_loopback: bool, name: &str) -> SocketAddr {
|
||||
if addr.ip().is_loopback() {
|
||||
return addr;
|
||||
}
|
||||
|
||||
if allow_non_loopback {
|
||||
warn!("DANGEROUS: {name} listening on non-loopback address {addr}");
|
||||
return addr;
|
||||
}
|
||||
|
||||
warn!(
|
||||
"{name} requested non-loopback bind ({addr}); clamping to 127.0.0.1:{port} (set dangerously_allow_non_loopback_proxy or dangerously_allow_non_loopback_admin to override)",
|
||||
port = addr.port()
|
||||
);
|
||||
SocketAddr::from(([127, 0, 0, 1], addr.port()))
|
||||
}
|
||||
|
||||
pub(crate) fn clamp_bind_addrs(
|
||||
http_addr: SocketAddr,
|
||||
admin_addr: SocketAddr,
|
||||
cfg: &NetworkProxySettings,
|
||||
) -> (SocketAddr, SocketAddr) {
|
||||
let http_addr = clamp_non_loopback(
|
||||
http_addr,
|
||||
cfg.dangerously_allow_non_loopback_proxy,
|
||||
"HTTP proxy",
|
||||
);
|
||||
let admin_addr = clamp_non_loopback(
|
||||
admin_addr,
|
||||
cfg.dangerously_allow_non_loopback_admin,
|
||||
"admin API",
|
||||
);
|
||||
if cfg.policy.allow_unix_sockets.is_empty() {
|
||||
return (http_addr, admin_addr);
|
||||
}
|
||||
|
||||
// `x-unix-socket` is intentionally a local escape hatch. If the proxy (or admin API) is
|
||||
// reachable from outside the machine, it can become a remote bridge into local daemons
|
||||
// (e.g. docker.sock). To avoid footguns, enforce loopback binding whenever unix sockets
|
||||
// are enabled.
|
||||
if cfg.dangerously_allow_non_loopback_proxy && !http_addr.ip().is_loopback() {
|
||||
warn!(
|
||||
"unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_proxy and clamping HTTP proxy to loopback"
|
||||
);
|
||||
}
|
||||
if cfg.dangerously_allow_non_loopback_admin && !admin_addr.ip().is_loopback() {
|
||||
warn!(
|
||||
"unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_admin and clamping admin API to loopback"
|
||||
);
|
||||
}
|
||||
(
|
||||
SocketAddr::from(([127, 0, 0, 1], http_addr.port())),
|
||||
SocketAddr::from(([127, 0, 0, 1], admin_addr.port())),
|
||||
)
|
||||
}
|
||||
|
||||
pub struct RuntimeConfig {
|
||||
pub http_addr: SocketAddr,
|
||||
pub admin_addr: SocketAddr,
|
||||
}
|
||||
|
||||
pub fn resolve_runtime(cfg: &NetworkProxyConfig) -> Result<RuntimeConfig> {
|
||||
let http_addr = resolve_addr(&cfg.network_proxy.proxy_url, 3128).with_context(|| {
|
||||
format!(
|
||||
"invalid network_proxy.proxy_url: {}",
|
||||
cfg.network_proxy.proxy_url
|
||||
)
|
||||
})?;
|
||||
let admin_addr = resolve_addr(&cfg.network_proxy.admin_url, 8080).with_context(|| {
|
||||
format!(
|
||||
"invalid network_proxy.admin_url: {}",
|
||||
cfg.network_proxy.admin_url
|
||||
)
|
||||
})?;
|
||||
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg.network_proxy);
|
||||
|
||||
Ok(RuntimeConfig {
|
||||
http_addr,
|
||||
admin_addr,
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_addr(url: &str, default_port: u16) -> Result<SocketAddr> {
|
||||
let addr_parts = parse_host_port(url, default_port)?;
|
||||
let host = if addr_parts.host.eq_ignore_ascii_case("localhost") {
|
||||
"127.0.0.1".to_string()
|
||||
} else {
|
||||
addr_parts.host
|
||||
};
|
||||
match host.parse::<IpAddr>() {
|
||||
Ok(ip) => Ok(SocketAddr::new(ip, addr_parts.port)),
|
||||
Err(_) => Ok(SocketAddr::from(([127, 0, 0, 1], addr_parts.port))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct SocketAddressParts {
|
||||
host: String,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
fn parse_host_port(url: &str, default_port: u16) -> Result<SocketAddressParts> {
|
||||
let trimmed = url.trim();
|
||||
if trimmed.is_empty() {
|
||||
bail!("missing host in network proxy address: {url}");
|
||||
}
|
||||
|
||||
// Avoid treating unbracketed IPv6 literals like "2001:db8::1" as scheme-prefixed URLs.
|
||||
if matches!(trimmed.parse::<IpAddr>(), Ok(IpAddr::V6(_))) && !trimmed.starts_with('[') {
|
||||
return Ok(SocketAddressParts {
|
||||
host: trimmed.to_string(),
|
||||
port: default_port,
|
||||
});
|
||||
}
|
||||
|
||||
// Prefer the standard URL parser when the input is URL-like. Prefix a scheme when absent so
|
||||
// we still accept loose host:port inputs.
|
||||
let candidate = if trimmed.contains("://") {
|
||||
trimmed.to_string()
|
||||
} else {
|
||||
format!("http://{trimmed}")
|
||||
};
|
||||
if let Ok(parsed) = Url::parse(&candidate)
|
||||
&& let Some(host) = parsed.host_str()
|
||||
{
|
||||
let host = host.trim_matches(|c| c == '[' || c == ']');
|
||||
if host.is_empty() {
|
||||
bail!("missing host in network proxy address: {url}");
|
||||
}
|
||||
return Ok(SocketAddressParts {
|
||||
host: host.to_string(),
|
||||
port: parsed.port().unwrap_or(default_port),
|
||||
});
|
||||
}
|
||||
|
||||
parse_host_port_fallback(trimmed, default_port)
|
||||
}
|
||||
|
||||
fn parse_host_port_fallback(input: &str, default_port: u16) -> Result<SocketAddressParts> {
|
||||
let without_scheme = input
|
||||
.split_once("://")
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(input);
|
||||
let host_port = without_scheme.split('/').next().unwrap_or(without_scheme);
|
||||
let host_port = host_port
|
||||
.rsplit_once('@')
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(host_port);
|
||||
|
||||
if host_port.starts_with('[')
|
||||
&& let Some(end) = host_port.find(']')
|
||||
{
|
||||
let host = &host_port[1..end];
|
||||
let port = host_port[end + 1..]
|
||||
.strip_prefix(':')
|
||||
.and_then(|port| port.parse::<u16>().ok())
|
||||
.unwrap_or(default_port);
|
||||
if host.is_empty() {
|
||||
bail!("missing host in network proxy address: {input}");
|
||||
}
|
||||
return Ok(SocketAddressParts {
|
||||
host: host.to_string(),
|
||||
port,
|
||||
});
|
||||
}
|
||||
|
||||
// Only treat `host:port` as such when there's a single `:`. This avoids
|
||||
// accidentally interpreting unbracketed IPv6 addresses as `host:port`.
|
||||
if host_port.bytes().filter(|b| *b == b':').count() == 1
|
||||
&& let Some((host, port)) = host_port.rsplit_once(':')
|
||||
&& let Ok(port) = port.parse::<u16>()
|
||||
{
|
||||
if host.is_empty() {
|
||||
bail!("missing host in network proxy address: {input}");
|
||||
}
|
||||
return Ok(SocketAddressParts {
|
||||
host: host.to_string(),
|
||||
port,
|
||||
});
|
||||
}
|
||||
|
||||
if host_port.is_empty() {
|
||||
bail!("missing host in network proxy address: {input}");
|
||||
}
|
||||
Ok(SocketAddressParts {
|
||||
host: host_port.to_string(),
|
||||
port: default_port,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_errors_on_empty_string() {
|
||||
assert!(parse_host_port("", 1234).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_errors_on_whitespace() {
|
||||
assert!(parse_host_port(" ", 5555).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_host_port_without_scheme() {
|
||||
assert_eq!(
|
||||
parse_host_port("127.0.0.1:8080", 3128).unwrap(),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 8080,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_host_port_with_scheme_and_path() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://example.com:8080/some/path", 3128).unwrap(),
|
||||
SocketAddressParts {
|
||||
host: "example.com".to_string(),
|
||||
port: 8080,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_strips_userinfo() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://user:pass@host.example:5555", 3128).unwrap(),
|
||||
SocketAddressParts {
|
||||
host: "host.example".to_string(),
|
||||
port: 5555,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_ipv6_with_brackets() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://[::1]:9999", 3128).unwrap(),
|
||||
SocketAddressParts {
|
||||
host: "::1".to_string(),
|
||||
port: 9999,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_does_not_treat_unbracketed_ipv6_as_host_port() {
|
||||
assert_eq!(
|
||||
parse_host_port("2001:db8::1", 3128).unwrap(),
|
||||
SocketAddressParts {
|
||||
host: "2001:db8::1".to_string(),
|
||||
port: 3128,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_falls_back_to_default_port_when_port_is_invalid() {
|
||||
assert_eq!(
|
||||
parse_host_port("example.com:notaport", 3128).unwrap(),
|
||||
SocketAddressParts {
|
||||
host: "example.com:notaport".to_string(),
|
||||
port: 3128,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_maps_localhost_to_loopback() {
|
||||
assert_eq!(
|
||||
resolve_addr("localhost", 3128).unwrap(),
|
||||
"127.0.0.1:3128".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ip_literals() {
|
||||
assert_eq!(
|
||||
resolve_addr("1.2.3.4", 80).unwrap(),
|
||||
"1.2.3.4:80".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ipv6_literals() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://[::1]:8080", 3128).unwrap(),
|
||||
"[::1]:8080".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://example.com:5555", 3128).unwrap(),
|
||||
"127.0.0.1:5555".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_bind_addrs_allows_non_loopback_when_enabled() {
|
||||
let cfg = NetworkProxySettings {
|
||||
dangerously_allow_non_loopback_proxy: true,
|
||||
dangerously_allow_non_loopback_admin: true,
|
||||
..Default::default()
|
||||
};
|
||||
let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
|
||||
let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
|
||||
|
||||
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
|
||||
|
||||
assert_eq!(http_addr, "0.0.0.0:3128".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(admin_addr, "0.0.0.0:8080".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_bind_addrs_forces_loopback_when_unix_sockets_enabled() {
|
||||
let cfg = NetworkProxySettings {
|
||||
dangerously_allow_non_loopback_proxy: true,
|
||||
dangerously_allow_non_loopback_admin: true,
|
||||
policy: NetworkPolicy {
|
||||
allow_unix_sockets: vec!["/tmp/docker.sock".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
|
||||
let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
|
||||
|
||||
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
|
||||
|
||||
assert_eq!(http_addr, "127.0.0.1:3128".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(admin_addr, "127.0.0.1:8080".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
}
|
||||
codex-rs/network-proxy/src/http_proxy.rs (new file, 636 lines)
@@ -0,0 +1,636 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::network_policy::NetworkDecision;
|
||||
use crate::network_policy::NetworkPolicyDecider;
|
||||
use crate::network_policy::NetworkPolicyRequest;
|
||||
use crate::network_policy::NetworkProtocol;
|
||||
use crate::network_policy::evaluate_host_policy;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_PROXY_DISABLED;
|
||||
use crate::responses::blocked_header_value;
|
||||
use crate::responses::json_response;
|
||||
use crate::runtime::unix_socket_permissions_supported;
|
||||
use crate::state::BlockedRequest;
|
||||
use crate::state::NetworkProxyState;
|
||||
use crate::upstream::UpstreamClient;
|
||||
use crate::upstream::proxy_for_connect;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use rama_core::Layer;
|
||||
use rama_core::Service;
|
||||
use rama_core::error::BoxError;
|
||||
use rama_core::error::ErrorExt as _;
|
||||
use rama_core::error::OpaqueError;
|
||||
use rama_core::extensions::ExtensionsMut;
|
||||
use rama_core::extensions::ExtensionsRef;
|
||||
use rama_core::layer::AddInputExtensionLayer;
|
||||
use rama_core::rt::Executor;
|
||||
use rama_core::service::service_fn;
|
||||
use rama_http::Body;
|
||||
use rama_http::HeaderValue;
|
||||
use rama_http::Request;
|
||||
use rama_http::Response;
|
||||
use rama_http::StatusCode;
|
||||
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama_http::matcher::MethodMatcher;
|
||||
use rama_http_backend::client::proxy::layer::HttpProxyConnector;
|
||||
use rama_http_backend::server::HttpServer;
|
||||
use rama_http_backend::server::layer::upgrade::UpgradeLayer;
|
||||
use rama_http_backend::server::layer::upgrade::Upgraded;
|
||||
use rama_net::Protocol;
|
||||
use rama_net::address::ProxyAddress;
|
||||
use rama_net::client::ConnectorService;
|
||||
use rama_net::client::EstablishedClientConnection;
|
||||
use rama_net::http::RequestContext;
|
||||
use rama_net::proxy::ProxyRequest;
|
||||
use rama_net::proxy::ProxyTarget;
|
||||
use rama_net::proxy::StreamForwardService;
|
||||
use rama_net::stream::SocketInfo;
|
||||
use rama_tcp::client::Request as TcpRequest;
|
||||
use rama_tcp::client::service::TcpConnector;
|
||||
use rama_tcp::server::TcpListener;
|
||||
use rama_tls_boring::client::TlsConnectorDataBuilder;
|
||||
use rama_tls_boring::client::TlsConnectorLayer;
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn run_http_proxy(
|
||||
state: Arc<NetworkProxyState>,
|
||||
addr: SocketAddr,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
) -> Result<()> {
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
// Rama's `BoxError` is a `Box<dyn Error + Send + Sync>` without an explicit `'static`
|
||||
// lifetime bound, which means it doesn't satisfy `anyhow::Context`'s `StdError` constraint.
|
||||
// Wrap it in Rama's `OpaqueError` so we can preserve the original error as a source and
|
||||
// still use `anyhow` for chaining.
|
||||
.map_err(rama_core::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
|
||||
|
||||
let http_service = HttpServer::auto(Executor::new()).service(
|
||||
(
|
||||
UpgradeLayer::new(
|
||||
MethodMatcher::CONNECT,
|
||||
service_fn({
|
||||
let policy_decider = policy_decider.clone();
|
||||
move |req| http_connect_accept(policy_decider.clone(), req)
|
||||
}),
|
||||
service_fn(http_connect_proxy),
|
||||
),
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
)
|
||||
.into_layer(service_fn({
|
||||
let policy_decider = policy_decider.clone();
|
||||
move |req| http_plain_proxy(policy_decider.clone(), req)
|
||||
})),
|
||||
);
|
||||
|
||||
info!("HTTP proxy listening on {addr}");
|
||||
|
||||
listener
|
||||
.serve(AddInputExtensionLayer::new(state).into_layer(http_service))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn http_connect_accept(
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
mut req: Request,
|
||||
) -> Result<(Response, Request), Response> {
|
||||
let app_state = req
|
||||
.extensions()
|
||||
.get::<Arc<NetworkProxyState>>()
|
||||
.cloned()
|
||||
.ok_or_else(|| text_response(StatusCode::INTERNAL_SERVER_ERROR, "missing state"))?;
|
||||
|
||||
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!("CONNECT missing authority: {err}");
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
|
||||
}
|
||||
};
|
||||
|
||||
let host = normalize_host(&authority.host.to_string());
|
||||
if host.is_empty() {
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
|
||||
}
|
||||
|
||||
let client = client_addr(&req);
|
||||
|
||||
let enabled = app_state
|
||||
.enabled()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read enabled state", err))?;
|
||||
if !enabled {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked; proxy disabled (client={client}, host={host})");
|
||||
return Err(proxy_disabled_response(
|
||||
&app_state,
|
||||
host,
|
||||
client_addr(&req),
|
||||
Some("CONNECT".to_string()),
|
||||
"http-connect",
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::HttpsConnect,
|
||||
host.clone(),
|
||||
authority.port,
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
|
||||
Ok(NetworkDecision::Deny { reason }) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
None,
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(blocked_text(&reason));
|
||||
}
|
||||
Ok(NetworkDecision::Allow) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("CONNECT allowed (client={client}, host={host})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host for CONNECT {host}: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
|
||||
let mode = app_state
|
||||
.network_mode()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read network mode", err))?;
|
||||
|
||||
if mode == NetworkMode::Limited {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
REASON_METHOD_NOT_ALLOWED.to_string(),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
Some(NetworkMode::Limited),
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked by method policy (client={client}, host={host}, mode=limited)");
|
||||
return Err(blocked_text(REASON_METHOD_NOT_ALLOWED));
|
||||
}
|
||||
|
||||
req.extensions_mut().insert(ProxyTarget(authority));
|
||||
req.extensions_mut().insert(mode);
|
||||
|
||||
Ok((
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::empty())
|
||||
.unwrap_or_else(|_| Response::new(Body::empty())),
|
||||
req,
|
||||
))
|
||||
}
|
||||
|
||||
async fn http_connect_proxy(upgraded: Upgraded) -> Result<(), Infallible> {
|
||||
if upgraded.extensions().get::<ProxyTarget>().is_none() {
|
||||
warn!("CONNECT missing proxy target");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let allow_upstream_proxy = match upgraded
|
||||
.extensions()
|
||||
.get::<Arc<NetworkProxyState>>()
|
||||
.cloned()
|
||||
{
|
||||
Some(state) => match state.allow_upstream_proxy().await {
|
||||
Ok(allowed) => allowed,
|
||||
Err(err) => {
|
||||
error!("failed to read upstream proxy setting: {err}");
|
||||
false
|
||||
}
|
||||
},
|
||||
None => {
|
||||
error!("missing app state");
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
let proxy = if allow_upstream_proxy {
|
||||
proxy_for_connect()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Err(err) = forward_connect_tunnel(upgraded, proxy).await {
|
||||
warn!("tunnel error: {err}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn forward_connect_tunnel(
|
||||
upgraded: Upgraded,
|
||||
proxy: Option<ProxyAddress>,
|
||||
) -> Result<(), BoxError> {
|
||||
let authority = upgraded
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.map(|target| target.0.clone())
|
||||
.ok_or_else(|| OpaqueError::from_display("missing forward authority").into_boxed())?;
|
||||
|
||||
let mut extensions = upgraded.extensions().clone();
|
||||
if let Some(proxy) = proxy {
|
||||
extensions.insert(proxy);
|
||||
}
|
||||
|
||||
let req = TcpRequest::new_with_extensions(authority.clone(), extensions)
|
||||
.with_protocol(Protocol::HTTPS);
|
||||
let proxy_connector = HttpProxyConnector::optional(TcpConnector::new());
|
||||
let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
|
||||
let connector = TlsConnectorLayer::tunnel(None)
|
||||
.with_connector_data(tls_config)
|
||||
.into_layer(proxy_connector);
|
||||
let EstablishedClientConnection { conn: target, .. } =
|
||||
connector.connect(req).await.map_err(|err| {
|
||||
OpaqueError::from_boxed(err)
|
||||
.with_context(|| format!("establish CONNECT tunnel to {authority}"))
|
||||
.into_boxed()
|
||||
})?;
|
||||
|
||||
let proxy_req = ProxyRequest {
|
||||
source: upgraded,
|
||||
target,
|
||||
};
|
||||
StreamForwardService::default()
|
||||
.serve(proxy_req)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
OpaqueError::from_boxed(err.into())
|
||||
.with_context(|| format!("forward CONNECT tunnel to {authority}"))
|
||||
.into_boxed()
|
||||
})
|
||||
}
|
||||
|
||||
async fn http_plain_proxy(
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
req: Request,
|
||||
) -> Result<Response, Infallible> {
|
||||
let app_state = match req.extensions().get::<Arc<NetworkProxyState>>().cloned() {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
error!("missing app state");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
let client = client_addr(&req);
|
||||
|
||||
let method_allowed = match app_state
|
||||
.method_allowed(req.method().as_str())
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to evaluate method policy", err))
|
||||
{
|
||||
Ok(allowed) => allowed,
|
||||
Err(resp) => return Ok(resp),
|
||||
};
|
||||
|
||||
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly scoped:
|
||||
// macOS-only + explicit allowlist, to avoid turning the proxy into a general local capability
|
||||
// escalation mechanism.
|
||||
if let Some(unix_socket_header) = req.headers().get("x-unix-socket") {
|
||||
let socket_path = match unix_socket_header.to_str() {
|
||||
Ok(value) => value.to_string(),
|
||||
Err(_) => {
|
||||
warn!("invalid x-unix-socket header value (non-UTF8)");
|
||||
return Ok(text_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid x-unix-socket header",
|
||||
));
|
||||
}
|
||||
};
|
||||
let enabled = match app_state
|
||||
.enabled()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read enabled state", err))
|
||||
{
|
||||
Ok(enabled) => enabled,
|
||||
Err(resp) => return Ok(resp),
|
||||
};
|
||||
if !enabled {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("unix socket blocked; proxy disabled (client={client}, path={socket_path})");
|
||||
return Ok(proxy_disabled_response(
|
||||
&app_state,
|
||||
socket_path,
|
||||
client_addr(&req),
|
||||
Some(req.method().as_str().to_string()),
|
||||
"unix-socket",
|
||||
)
|
||||
.await);
|
||||
}
|
||||
if !method_allowed {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!(
|
||||
"unix socket blocked by method policy (client={client}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(json_blocked("unix-socket", REASON_METHOD_NOT_ALLOWED));
|
||||
}
|
||||
|
||||
if !unix_socket_permissions_supported() {
|
||||
warn!("unix socket proxy unsupported on this platform (path={socket_path})");
|
||||
return Ok(text_response(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"unix sockets unsupported",
|
||||
));
|
||||
}
|
||||
|
||||
return match app_state.is_unix_socket_allowed(&socket_path).await {
|
||||
Ok(true) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("unix socket allowed (client={client}, path={socket_path})");
|
||||
match proxy_via_unix_socket(req, &socket_path).await {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(err) => {
|
||||
warn!("unix socket proxy failed: {err}");
|
||||
Ok(text_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"unix socket proxy failed",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("unix socket blocked (client={client}, path={socket_path})");
|
||||
Ok(json_blocked("unix-socket", REASON_NOT_ALLOWED))
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("unix socket check failed: {err}");
|
||||
Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!("missing host: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
|
||||
}
|
||||
};
|
||||
let host = normalize_host(&authority.host.to_string());
|
||||
let port = authority.port;
|
||||
let enabled = match app_state
|
||||
.enabled()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read enabled state", err))
|
||||
{
|
||||
Ok(enabled) => enabled,
|
||||
Err(resp) => return Ok(resp),
|
||||
};
|
||||
if !enabled {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!("request blocked; proxy disabled (client={client}, host={host}, method={method})");
|
||||
return Ok(proxy_disabled_response(
|
||||
&app_state,
|
||||
host,
|
||||
client_addr(&req),
|
||||
Some(req.method().as_str().to_string()),
|
||||
"http",
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Http,
|
||||
host.clone(),
|
||||
port,
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
|
||||
Ok(NetworkDecision::Deny { reason }) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("request blocked (client={client}, host={host}, reason={reason})");
|
||||
return Ok(json_blocked(&host, &reason));
|
||||
}
|
||||
Ok(NetworkDecision::Allow) => {}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host for {host}: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
|
||||
if !method_allowed {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
REASON_METHOD_NOT_ALLOWED.to_string(),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
Some(NetworkMode::Limited),
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!(
|
||||
"request blocked by method policy (client={client}, host={host}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(json_blocked(&host, REASON_METHOD_NOT_ALLOWED));
|
||||
}
|
||||
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
info!("request allowed (client={client}, host={host}, method={method})");
|
||||
|
||||
let allow_upstream_proxy = match app_state
|
||||
.allow_upstream_proxy()
|
||||
.await
|
||||
.map_err(|err| internal_error("failed to read upstream proxy config", err))
|
||||
{
|
||||
Ok(allow) => allow,
|
||||
Err(resp) => return Ok(resp),
|
||||
};
|
||||
let client = if allow_upstream_proxy {
|
||||
UpstreamClient::from_env_proxy()
|
||||
} else {
|
||||
UpstreamClient::direct()
|
||||
};
|
||||
|
||||
match client.serve(req).await {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(err) => {
|
||||
warn!("upstream request failed: {err}");
|
||||
            Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
        }
    }
}

async fn proxy_via_unix_socket(req: Request, socket_path: &str) -> Result<Response> {
    #[cfg(target_os = "macos")]
    {
        let client = UpstreamClient::unix_socket(socket_path);

        let (mut parts, body) = req.into_parts();
        let path = parts
            .uri
            .path_and_query()
            .map(rama_http::uri::PathAndQuery::as_str)
            .unwrap_or("/");
        parts.uri = path
            .parse()
            .with_context(|| format!("invalid unix socket request path: {path}"))?;
        parts.headers.remove("x-unix-socket");

        let req = Request::from_parts(parts, body);
        client.serve(req).await.map_err(anyhow::Error::from)
    }
    #[cfg(not(target_os = "macos"))]
    {
        let _ = req;
        let _ = socket_path;
        Err(anyhow::anyhow!("unix sockets not supported"))
    }
}

fn client_addr<T: ExtensionsRef>(input: &T) -> Option<String> {
    input
        .extensions()
        .get::<SocketInfo>()
        .map(|info| info.peer_addr().to_string())
}

fn json_blocked(host: &str, reason: &str) -> Response {
    let response = BlockedResponse {
        status: "blocked",
        host,
        reason,
    };
    let mut resp = json_response(&response);
    *resp.status_mut() = StatusCode::FORBIDDEN;
    resp.headers_mut().insert(
        "x-proxy-error",
        HeaderValue::from_static(blocked_header_value(reason)),
    );
    resp
}

fn blocked_text(reason: &str) -> Response {
    crate::responses::blocked_text_response(reason)
}

async fn proxy_disabled_response(
    app_state: &NetworkProxyState,
    host: String,
    client: Option<String>,
    method: Option<String>,
    protocol: &str,
) -> Response {
    let _ = app_state
        .record_blocked(BlockedRequest::new(
            host,
            REASON_PROXY_DISABLED.to_string(),
            client,
            method,
            None,
            protocol.to_string(),
        ))
        .await;
    text_response(StatusCode::SERVICE_UNAVAILABLE, "proxy disabled")
}

fn internal_error(context: &str, err: impl std::fmt::Display) -> Response {
    error!("{context}: {err}");
    text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}

fn text_response(status: StatusCode, body: &str) -> Response {
    Response::builder()
        .status(status)
        .header("content-type", "text/plain")
        .body(Body::from(body.to_string()))
        .unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}

#[derive(Serialize)]
struct BlockedResponse<'a> {
    status: &'static str,
    host: &'a str,
    reason: &'a str,
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::config::NetworkMode;
    use crate::config::NetworkPolicy;
    use crate::runtime::network_proxy_state_for_policy;
    use pretty_assertions::assert_eq;
    use rama_http::Method;
    use rama_http::Request;
    use std::sync::Arc;

    #[tokio::test]
    async fn http_connect_accept_blocks_in_limited_mode() {
        let policy = NetworkPolicy {
            allowed_domains: vec!["example.com".to_string()],
            ..Default::default()
        };
        let state = Arc::new(network_proxy_state_for_policy(policy));
        state.set_network_mode(NetworkMode::Limited).await.unwrap();

        let mut req = Request::builder()
            .method(Method::CONNECT)
            .uri("https://example.com:443")
            .header("host", "example.com:443")
            .body(Body::empty())
            .unwrap();
        req.extensions_mut().insert(state);

        let response = http_connect_accept(None, req).await.unwrap_err();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        assert_eq!(
            response.headers().get("x-proxy-error").unwrap(),
            "blocked-by-method-policy"
        );
    }
}
29
codex-rs/network-proxy/src/lib.rs
Normal file
@@ -0,0 +1,29 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]

mod admin;
mod config;
mod http_proxy;
mod network_policy;
mod policy;
mod proxy;
mod reasons;
mod responses;
mod runtime;
mod state;
mod upstream;

use anyhow::Result;
pub use network_policy::NetworkDecision;
pub use network_policy::NetworkPolicyDecider;
pub use network_policy::NetworkPolicyRequest;
pub use network_policy::NetworkProtocol;
pub use proxy::Args;
pub use proxy::NetworkProxy;
pub use proxy::NetworkProxyBuilder;
pub use proxy::NetworkProxyHandle;

pub async fn run_main(args: Args) -> Result<()> {
    let _ = args;
    let proxy = NetworkProxy::builder().build().await?;
    proxy.run().await?.wait().await
}
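
The re-exports above are enough to embed the proxy in another process. A minimal sketch, assuming a host crate that depends on codex-network-proxy; the listen addresses and the always-allow closure decider are illustrative placeholders, not values defined in this PR:

// Hypothetical embedding sketch (not part of this diff).
use std::net::SocketAddr;

use anyhow::Result;
use codex_network_proxy::{NetworkDecision, NetworkPolicyRequest, NetworkProxy};

async fn start_proxy() -> Result<()> {
    // Illustrative ports; real values would come from the resolved config.
    let http_addr: SocketAddr = "127.0.0.1:3128".parse()?;
    let admin_addr: SocketAddr = "127.0.0.1:3129".parse()?;

    let proxy = NetworkProxy::builder()
        .http_addr(http_addr)
        .admin_addr(admin_addr)
        // Closures work as deciders via the blanket `Fn` impl in network_policy.rs.
        .policy_decider(|_req: NetworkPolicyRequest| async { NetworkDecision::Allow })
        .build()
        .await?;

    // `run` spawns the HTTP proxy and admin listeners; `wait` joins both tasks.
    proxy.run().await?.wait().await
}
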
14
codex-rs/network-proxy/src/main.rs
Normal file
@@ -0,0 +1,14 @@
use anyhow::Result;
use clap::Parser;
use codex_network_proxy::Args;
use codex_network_proxy::NetworkProxy;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    let args = Args::parse();
    let _ = args;
    let proxy = NetworkProxy::builder().build().await?;
    proxy.run().await?.wait().await
}
234
codex-rs/network-proxy/src/network_policy.rs
Normal file
@@ -0,0 +1,234 @@
use crate::reasons::REASON_POLICY_DENIED;
use crate::runtime::HostBlockDecision;
use crate::runtime::HostBlockReason;
use crate::state::NetworkProxyState;
use anyhow::Result;
use async_trait::async_trait;
use std::future::Future;
use std::sync::Arc;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NetworkProtocol {
    Http,
    HttpsConnect,
    Socks5Tcp,
    Socks5Udp,
}

#[derive(Clone, Debug)]
pub struct NetworkPolicyRequest {
    pub protocol: NetworkProtocol,
    pub host: String,
    pub port: u16,
    pub client_addr: Option<String>,
    pub method: Option<String>,
    pub command: Option<String>,
    pub exec_policy_hint: Option<String>,
}

impl NetworkPolicyRequest {
    pub fn new(
        protocol: NetworkProtocol,
        host: String,
        port: u16,
        client_addr: Option<String>,
        method: Option<String>,
        command: Option<String>,
        exec_policy_hint: Option<String>,
    ) -> Self {
        Self {
            protocol,
            host,
            port,
            client_addr,
            method,
            command,
            exec_policy_hint,
        }
    }
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NetworkDecision {
    Allow,
    Deny { reason: String },
}

impl NetworkDecision {
    pub fn deny(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        let reason = if reason.is_empty() {
            REASON_POLICY_DENIED.to_string()
        } else {
            reason
        };
        Self::Deny { reason }
    }
}

/// Decide whether a network request should be allowed.
///
/// If `command` or `exec_policy_hint` is provided, callers can map exec-policy
/// approvals to network access (e.g., allow all requests for commands matching
/// approved prefixes like `curl *`).
#[async_trait]
pub trait NetworkPolicyDecider: Send + Sync + 'static {
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision;
}

#[async_trait]
impl<D: NetworkPolicyDecider + ?Sized> NetworkPolicyDecider for Arc<D> {
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
        (**self).decide(req).await
    }
}

#[async_trait]
impl<F, Fut> NetworkPolicyDecider for F
where
    F: Fn(NetworkPolicyRequest) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = NetworkDecision> + Send,
{
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
        (self)(req).await
    }
}
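
// Hypothetical decider sketch (not part of this diff): shows one way a caller
// could implement the trait above to map previously approved exec commands,
// such as anything starting with "curl ", onto network access, as the doc
// comment suggests. The `ApprovedPrefixes` type and its prefix list are
// illustrative assumptions only.
struct ApprovedPrefixes {
    prefixes: Vec<String>,
}

#[async_trait]
impl NetworkPolicyDecider for ApprovedPrefixes {
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
        // Allow the request only when the originating command matches an
        // already-approved prefix; otherwise fall back to a deny reason.
        match req.command.as_deref() {
            Some(command)
                if self
                    .prefixes
                    .iter()
                    .any(|prefix| command.starts_with(prefix)) =>
            {
                NetworkDecision::Allow
            }
            _ => NetworkDecision::deny("not approved by exec policy"),
        }
    }
}
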
pub(crate) async fn evaluate_host_policy(
    state: &NetworkProxyState,
    decider: Option<&Arc<dyn NetworkPolicyDecider>>,
    request: &NetworkPolicyRequest,
) -> Result<NetworkDecision> {
    match state.host_blocked(&request.host, request.port).await? {
        HostBlockDecision::Allowed => Ok(NetworkDecision::Allow),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowed) => {
            if let Some(decider) = decider {
                Ok(decider.decide(request.clone()).await)
            } else {
                Ok(NetworkDecision::deny(HostBlockReason::NotAllowed.as_str()))
            }
        }
        HostBlockDecision::Blocked(reason) => Ok(NetworkDecision::deny(reason.as_str())),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::config::NetworkPolicy;
    use crate::reasons::REASON_DENIED;
    use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
    use crate::state::network_proxy_state_for_policy;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    #[tokio::test]
    async fn evaluate_host_policy_invokes_decider_for_not_allowed() {
        let state = network_proxy_state_for_policy(NetworkPolicy::default());
        let calls = Arc::new(AtomicUsize::new(0));
        let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
            let calls = calls.clone();
            move |_req| {
                calls.fetch_add(1, Ordering::SeqCst);
                // The default policy denies all; the decider is consulted for not_allowed
                // requests and can override that decision.
                async { NetworkDecision::Allow }
            }
        });

        let request = NetworkPolicyRequest::new(
            NetworkProtocol::Http,
            "example.com".to_string(),
            80,
            None,
            Some("GET".to_string()),
            None,
            None,
        );

        let decision = evaluate_host_policy(&state, Some(&decider), &request)
            .await
            .unwrap();
        assert_eq!(decision, NetworkDecision::Allow);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn evaluate_host_policy_skips_decider_for_denied() {
        let state = network_proxy_state_for_policy(NetworkPolicy {
            allowed_domains: vec!["example.com".to_string()],
            denied_domains: vec!["blocked.com".to_string()],
            ..NetworkPolicy::default()
        });
        let calls = Arc::new(AtomicUsize::new(0));
        let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
            let calls = calls.clone();
            move |_req| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { NetworkDecision::Allow }
            }
        });

        let request = NetworkPolicyRequest::new(
            NetworkProtocol::Http,
            "blocked.com".to_string(),
            80,
            None,
            Some("GET".to_string()),
            None,
            None,
        );

        let decision = evaluate_host_policy(&state, Some(&decider), &request)
            .await
            .unwrap();
        assert_eq!(
            decision,
            NetworkDecision::Deny {
                reason: REASON_DENIED.to_string()
            }
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn evaluate_host_policy_skips_decider_for_not_allowed_local() {
        let state = network_proxy_state_for_policy(NetworkPolicy {
            allowed_domains: vec!["example.com".to_string()],
            allow_local_binding: false,
            ..NetworkPolicy::default()
        });
        let calls = Arc::new(AtomicUsize::new(0));
        let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
            let calls = calls.clone();
            move |_req| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { NetworkDecision::Allow }
            }
        });

        let request = NetworkPolicyRequest::new(
            NetworkProtocol::Http,
            "127.0.0.1".to_string(),
            80,
            None,
            Some("GET".to_string()),
            None,
            None,
        );

        let decision = evaluate_host_policy(&state, Some(&decider), &request)
            .await
            .unwrap();
        assert_eq!(
            decision,
            NetworkDecision::Deny {
                reason: REASON_NOT_ALLOWED_LOCAL.to_string()
            }
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}
435
codex-rs/network-proxy/src/policy.rs
Normal file
@@ -0,0 +1,435 @@
|
||||
#[cfg(test)]
|
||||
use crate::config::NetworkMode;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::ensure;
|
||||
use globset::GlobBuilder;
|
||||
use globset::GlobSet;
|
||||
use globset::GlobSetBuilder;
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::net::Ipv6Addr;
|
||||
use url::Host as UrlHost;
|
||||
|
||||
/// A normalized host string for policy evaluation.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct Host(String);
|
||||
|
||||
impl Host {
|
||||
pub fn parse(input: &str) -> Result<Self> {
|
||||
let normalized = normalize_host(input);
|
||||
ensure!(!normalized.is_empty(), "host is empty");
|
||||
Ok(Self(normalized))
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the host is a loopback hostname or IP literal.
|
||||
pub fn is_loopback_host(host: &Host) -> bool {
|
||||
let host = host.as_str();
|
||||
let host = host.split_once('%').map(|(ip, _)| ip).unwrap_or(host);
|
||||
if host == "localhost" {
|
||||
return true;
|
||||
}
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return ip.is_loopback();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn is_non_public_ip(ip: IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ip) => is_non_public_ipv4(ip),
|
||||
IpAddr::V6(ip) => is_non_public_ipv6(ip),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
|
||||
// Use the standard library classification helpers where possible; they encode the intent more
|
||||
// clearly than hand-rolled range checks. Some non-public ranges (e.g., CGNAT and TEST-NET
|
||||
// blocks) are not covered by stable stdlib helpers yet, so we fall back to CIDR checks.
|
||||
ip.is_loopback()
|
||||
|| ip.is_private()
|
||||
|| ip.is_link_local()
|
||||
|| ip.is_unspecified()
|
||||
|| ip.is_multicast()
|
||||
|| ip.is_broadcast()
|
||||
|| ipv4_in_cidr(ip, [0, 0, 0, 0], 8) // "this network" (RFC 1122)
|
||||
|| ipv4_in_cidr(ip, [100, 64, 0, 0], 10) // CGNAT (RFC 6598)
|
||||
|| ipv4_in_cidr(ip, [192, 0, 0, 0], 24) // IETF Protocol Assignments (RFC 6890)
|
||||
|| ipv4_in_cidr(ip, [192, 0, 2, 0], 24) // TEST-NET-1 (RFC 5737)
|
||||
|| ipv4_in_cidr(ip, [198, 18, 0, 0], 15) // Benchmarking (RFC 2544)
|
||||
|| ipv4_in_cidr(ip, [198, 51, 100, 0], 24) // TEST-NET-2 (RFC 5737)
|
||||
|| ipv4_in_cidr(ip, [203, 0, 113, 0], 24) // TEST-NET-3 (RFC 5737)
|
||||
|| ipv4_in_cidr(ip, [240, 0, 0, 0], 4) // Reserved (RFC 6890)
|
||||
}
|
||||
|
||||
fn ipv4_in_cidr(ip: Ipv4Addr, base: [u8; 4], prefix: u8) -> bool {
|
||||
let ip = u32::from(ip);
|
||||
let base = u32::from(Ipv4Addr::from(base));
|
||||
let mask = if prefix == 0 {
|
||||
0
|
||||
} else {
|
||||
u32::MAX << (32 - prefix)
|
||||
};
|
||||
(ip & mask) == (base & mask)
|
||||
}
|
||||
|
||||
fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
|
||||
if let Some(v4) = ip.to_ipv4() {
|
||||
return is_non_public_ipv4(v4) || ip.is_loopback();
|
||||
}
|
||||
// Treat anything that isn't globally routable as "local" for SSRF prevention. In particular:
|
||||
// - `::1` loopback
|
||||
// - `fc00::/7` unique-local (RFC 4193)
|
||||
// - `fe80::/10` link-local
|
||||
// - `::` unspecified
|
||||
// - multicast ranges
|
||||
ip.is_loopback()
|
||||
|| ip.is_unspecified()
|
||||
|| ip.is_multicast()
|
||||
|| ip.is_unique_local()
|
||||
|| ip.is_unicast_link_local()
|
||||
}
|
||||
|
||||
/// Normalize host fragments for policy matching (trim whitespace, strip ports/brackets, lowercase).
|
||||
pub fn normalize_host(host: &str) -> String {
|
||||
let host = host.trim();
|
||||
if host.starts_with('[')
|
||||
&& let Some(end) = host.find(']')
|
||||
{
|
||||
return normalize_dns_host(&host[1..end]);
|
||||
}
|
||||
|
||||
// The proxy stack should typically hand us a host without a port, but be
|
||||
// defensive and strip `:port` when there is exactly one `:`.
|
||||
if host.bytes().filter(|b| *b == b':').count() == 1 {
|
||||
let host = host.split(':').next().unwrap_or_default();
|
||||
return normalize_dns_host(host);
|
||||
}
|
||||
|
||||
// Avoid mangling unbracketed IPv6 literals, but strip trailing dots so fully qualified domain
|
||||
// names are treated the same as their dotless variants.
|
||||
normalize_dns_host(host)
|
||||
}
|
||||
|
||||
fn normalize_dns_host(host: &str) -> String {
|
||||
let host = host.to_ascii_lowercase();
|
||||
host.trim_end_matches('.').to_string()
|
||||
}
|
||||
|
||||
fn normalize_pattern(pattern: &str) -> String {
|
||||
let pattern = pattern.trim();
|
||||
if pattern == "*" {
|
||||
return "*".to_string();
|
||||
}
|
||||
|
||||
let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
|
||||
("**.", domain)
|
||||
} else if let Some(domain) = pattern.strip_prefix("*.") {
|
||||
("*.", domain)
|
||||
} else {
|
||||
("", pattern)
|
||||
};
|
||||
|
||||
let remainder = normalize_host(remainder);
|
||||
if prefix.is_empty() {
|
||||
remainder
|
||||
} else {
|
||||
format!("{prefix}{remainder}")
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn compile_globset(patterns: &[String]) -> Result<GlobSet> {
|
||||
let mut builder = GlobSetBuilder::new();
|
||||
let mut seen = HashSet::new();
|
||||
for pattern in patterns {
|
||||
let pattern = normalize_pattern(pattern);
|
||||
// Supported domain patterns:
|
||||
// - "example.com": match the exact host
|
||||
// - "*.example.com": match any subdomain (not the apex)
|
||||
// - "**.example.com": match the apex and any subdomain
|
||||
// - "*": match any host
|
||||
for candidate in expand_domain_pattern(&pattern) {
|
||||
if !seen.insert(candidate.clone()) {
|
||||
continue;
|
||||
}
|
||||
let glob = GlobBuilder::new(&candidate)
|
||||
.case_insensitive(true)
|
||||
.build()
|
||||
.with_context(|| format!("invalid glob pattern: {candidate}"))?;
|
||||
builder.add(glob);
|
||||
}
|
||||
}
|
||||
Ok(builder.build()?)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum DomainPattern {
|
||||
Any,
|
||||
ApexAndSubdomains(String),
|
||||
SubdomainsOnly(String),
|
||||
Exact(String),
|
||||
}
|
||||
|
||||
impl DomainPattern {
|
||||
/// Parse a policy pattern for constraint comparisons.
|
||||
///
|
||||
/// Validation of glob syntax happens when building the globset; here we only
|
||||
/// decode the wildcard prefixes to keep constraint checks lightweight.
|
||||
pub(crate) fn parse(input: &str) -> Self {
|
||||
let input = input.trim();
|
||||
if input.is_empty() {
|
||||
return Self::Exact(String::new());
|
||||
}
|
||||
if input == "*" {
|
||||
Self::Any
|
||||
} else if let Some(domain) = input.strip_prefix("**.") {
|
||||
Self::parse_domain(domain, Self::ApexAndSubdomains)
|
||||
} else if let Some(domain) = input.strip_prefix("*.") {
|
||||
Self::parse_domain(domain, Self::SubdomainsOnly)
|
||||
} else {
|
||||
Self::Exact(input.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a policy pattern for constraint comparisons, validating domain parts with `url`.
|
||||
pub(crate) fn parse_for_constraints(input: &str) -> Self {
|
||||
let input = input.trim();
|
||||
if input.is_empty() {
|
||||
return Self::Exact(String::new());
|
||||
}
|
||||
if input == "*" {
|
||||
return Self::Any;
|
||||
}
|
||||
if let Some(domain) = input.strip_prefix("**.") {
|
||||
return Self::ApexAndSubdomains(parse_domain_for_constraints(domain));
|
||||
}
|
||||
if let Some(domain) = input.strip_prefix("*.") {
|
||||
return Self::SubdomainsOnly(parse_domain_for_constraints(domain));
|
||||
}
|
||||
Self::Exact(parse_domain_for_constraints(input))
|
||||
}
|
||||
|
||||
fn parse_domain(domain: &str, build: impl FnOnce(String) -> Self) -> Self {
|
||||
let domain = domain.trim();
|
||||
if domain.is_empty() {
|
||||
return Self::Exact(String::new());
|
||||
}
|
||||
build(domain.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
|
||||
match self {
|
||||
DomainPattern::Any => true,
|
||||
DomainPattern::Exact(domain) => match candidate {
|
||||
DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
|
||||
_ => false,
|
||||
},
|
||||
DomainPattern::SubdomainsOnly(domain) => match candidate {
|
||||
DomainPattern::Any => false,
|
||||
DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
|
||||
DomainPattern::SubdomainsOnly(candidate) => {
|
||||
is_subdomain_or_equal(candidate, domain)
|
||||
}
|
||||
DomainPattern::ApexAndSubdomains(candidate) => {
|
||||
is_strict_subdomain(candidate, domain)
|
||||
}
|
||||
},
|
||||
DomainPattern::ApexAndSubdomains(domain) => match candidate {
|
||||
DomainPattern::Any => false,
|
||||
DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
|
||||
DomainPattern::SubdomainsOnly(candidate) => {
|
||||
is_subdomain_or_equal(candidate, domain)
|
||||
}
|
||||
DomainPattern::ApexAndSubdomains(candidate) => {
|
||||
is_subdomain_or_equal(candidate, domain)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_domain_for_constraints(domain: &str) -> String {
|
||||
let domain = domain.trim().trim_end_matches('.');
|
||||
if domain.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
let host = if domain.starts_with('[') && domain.ends_with(']') {
|
||||
&domain[1..domain.len().saturating_sub(1)]
|
||||
} else {
|
||||
domain
|
||||
};
|
||||
if host.contains('*') || host.contains('?') || host.contains('%') {
|
||||
return domain.to_string();
|
||||
}
|
||||
match UrlHost::parse(host) {
|
||||
Ok(host) => host.to_string(),
|
||||
Err(_) => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_domain_pattern(pattern: &str) -> Vec<String> {
|
||||
match DomainPattern::parse(pattern) {
|
||||
DomainPattern::Any => vec![pattern.to_string()],
|
||||
DomainPattern::Exact(domain) => vec![domain],
|
||||
DomainPattern::SubdomainsOnly(domain) => {
|
||||
vec![format!("?*.{domain}")]
|
||||
}
|
||||
DomainPattern::ApexAndSubdomains(domain) => {
|
||||
vec![domain.clone(), format!("?*.{domain}")]
|
||||
}
|
||||
}
|
||||
}
|
||||
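// Illustrative sketch (not part of this diff): makes the expansion above
// concrete, mirroring the unit tests later in this file. The `?*.` glob prefix
// requires at least one character before the dot, which is what keeps the apex
// domain out of "*.example.com" while "**.example.com" covers both.
#[cfg(test)]
mod expansion_sketch {
    use super::*;

    #[test]
    fn wildcard_patterns_expand_as_documented() {
        assert_eq!(expand_domain_pattern("*.example.com"), vec!["?*.example.com"]);
        assert_eq!(
            expand_domain_pattern("**.example.com"),
            vec!["example.com", "?*.example.com"]
        );
    }
}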
|
||||
fn normalize_domain(domain: &str) -> String {
|
||||
domain.trim_end_matches('.').to_ascii_lowercase()
|
||||
}
|
||||
|
||||
fn domain_eq(left: &str, right: &str) -> bool {
|
||||
normalize_domain(left) == normalize_domain(right)
|
||||
}
|
||||
|
||||
fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
|
||||
let child = normalize_domain(child);
|
||||
let parent = normalize_domain(parent);
|
||||
if child == parent {
|
||||
return true;
|
||||
}
|
||||
child.ends_with(&format!(".{parent}"))
|
||||
}
|
||||
|
||||
fn is_strict_subdomain(child: &str, parent: &str) -> bool {
|
||||
let child = normalize_domain(child);
|
||||
let parent = normalize_domain(parent);
|
||||
child != parent && child.ends_with(&format!(".{parent}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn method_allowed_full_allows_everything() {
|
||||
assert!(NetworkMode::Full.allows_method("GET"));
|
||||
assert!(NetworkMode::Full.allows_method("POST"));
|
||||
assert!(NetworkMode::Full.allows_method("CONNECT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn method_allowed_limited_allows_only_safe_methods() {
|
||||
assert!(NetworkMode::Limited.allows_method("GET"));
|
||||
assert!(NetworkMode::Limited.allows_method("HEAD"));
|
||||
assert!(NetworkMode::Limited.allows_method("OPTIONS"));
|
||||
assert!(!NetworkMode::Limited.allows_method("POST"));
|
||||
assert!(!NetworkMode::Limited.allows_method("CONNECT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_trailing_dots() {
|
||||
let set = compile_globset(&["Example.COM.".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("example.com"));
|
||||
assert_eq!(false, set.is_match("api.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_wildcards() {
|
||||
let set = compile_globset(&["*.Example.COM.".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("api.example.com"));
|
||||
assert_eq!(false, set.is_match("example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_apex_and_subdomains() {
|
||||
let set = compile_globset(&["**.Example.COM.".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("example.com"));
|
||||
assert_eq!(true, set.is_match("api.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_bracketed_ipv6_literals() {
|
||||
let set = compile_globset(&["[::1]".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("::1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_loopback_host_handles_localhost_variants() {
|
||||
assert!(is_loopback_host(&Host::parse("localhost").unwrap()));
|
||||
assert!(is_loopback_host(&Host::parse("localhost.").unwrap()));
|
||||
assert!(is_loopback_host(&Host::parse("LOCALHOST").unwrap()));
|
||||
assert!(!is_loopback_host(&Host::parse("notlocalhost").unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_loopback_host_handles_ip_literals() {
|
||||
assert!(is_loopback_host(&Host::parse("127.0.0.1").unwrap()));
|
||||
assert!(is_loopback_host(&Host::parse("::1").unwrap()));
|
||||
assert!(!is_loopback_host(&Host::parse("1.2.3.4").unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_non_public_ip_rejects_private_and_loopback_ranges() {
|
||||
assert!(is_non_public_ip("127.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("10.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("192.168.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("100.64.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("192.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("192.0.2.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("198.18.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("198.51.100.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("203.0.113.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("240.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("0.1.2.3".parse().unwrap()));
|
||||
assert!(!is_non_public_ip("8.8.8.8".parse().unwrap()));
|
||||
|
||||
assert!(is_non_public_ip("::ffff:127.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("::ffff:10.0.0.1".parse().unwrap()));
|
||||
assert!(!is_non_public_ip("::ffff:8.8.8.8".parse().unwrap()));
|
||||
|
||||
assert!(is_non_public_ip("::1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("fe80::1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("fc00::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_lowercases_and_trims() {
|
||||
assert_eq!(normalize_host(" ExAmPlE.CoM "), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_port_for_host_port() {
|
||||
assert_eq!(normalize_host("example.com:1234"), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_preserves_unbracketed_ipv6() {
|
||||
assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_trailing_dot() {
|
||||
assert_eq!(normalize_host("example.com."), "example.com");
|
||||
assert_eq!(normalize_host("ExAmPlE.CoM."), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_trailing_dot_with_port() {
|
||||
assert_eq!(normalize_host("example.com.:443"), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_brackets_for_ipv6() {
|
||||
assert_eq!(normalize_host("[::1]"), "::1");
|
||||
assert_eq!(normalize_host("[::1]:443"), "::1");
|
||||
}
|
||||
}
|
||||
176
codex-rs/network-proxy/src/proxy.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use crate::admin;
|
||||
use crate::config;
|
||||
use crate::http_proxy;
|
||||
use crate::network_policy::NetworkPolicyDecider;
|
||||
use crate::runtime::unix_socket_permissions_supported;
|
||||
use crate::state::NetworkProxyState;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Clone, Parser)]
|
||||
#[command(name = "codex-network-proxy", about = "Codex network sandbox proxy")]
|
||||
pub struct Args {}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct NetworkProxyBuilder {
|
||||
state: Option<Arc<NetworkProxyState>>,
|
||||
http_addr: Option<SocketAddr>,
|
||||
admin_addr: Option<SocketAddr>,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
}
|
||||
|
||||
impl NetworkProxyBuilder {
|
||||
pub fn state(mut self, state: Arc<NetworkProxyState>) -> Self {
|
||||
self.state = Some(state);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn http_addr(mut self, addr: SocketAddr) -> Self {
|
||||
self.http_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn admin_addr(mut self, addr: SocketAddr) -> Self {
|
||||
self.admin_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn policy_decider<D>(mut self, decider: D) -> Self
|
||||
where
|
||||
D: NetworkPolicyDecider,
|
||||
{
|
||||
self.policy_decider = Some(Arc::new(decider));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn policy_decider_arc(mut self, decider: Arc<dyn NetworkPolicyDecider>) -> Self {
|
||||
self.policy_decider = Some(decider);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self) -> Result<NetworkProxy> {
|
||||
let state = match self.state {
|
||||
Some(state) => state,
|
||||
None => Arc::new(NetworkProxyState::new().await?),
|
||||
};
|
||||
let current_cfg = state.current_cfg().await?;
|
||||
let runtime = config::resolve_runtime(¤t_cfg)?;
|
||||
// Reapply bind clamping for caller overrides so unix-socket proxying stays loopback-only.
|
||||
let (http_addr, admin_addr) = config::clamp_bind_addrs(
|
||||
self.http_addr.unwrap_or(runtime.http_addr),
|
||||
self.admin_addr.unwrap_or(runtime.admin_addr),
|
||||
¤t_cfg.network_proxy,
|
||||
);
|
||||
|
||||
Ok(NetworkProxy {
|
||||
state,
|
||||
http_addr,
|
||||
admin_addr,
|
||||
policy_decider: self.policy_decider,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NetworkProxy {
|
||||
state: Arc<NetworkProxyState>,
|
||||
http_addr: SocketAddr,
|
||||
admin_addr: SocketAddr,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
}
|
||||
|
||||
impl NetworkProxy {
|
||||
pub fn builder() -> NetworkProxyBuilder {
|
||||
NetworkProxyBuilder::default()
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<NetworkProxyHandle> {
|
||||
let current_cfg = self.state.current_cfg().await?;
|
||||
if !current_cfg.network_proxy.enabled {
|
||||
warn!("network_proxy.enabled is false; skipping proxy listeners");
|
||||
return Ok(NetworkProxyHandle::noop());
|
||||
}
|
||||
|
||||
if !unix_socket_permissions_supported() {
|
||||
warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
|
||||
}
|
||||
|
||||
let http_task = tokio::spawn(http_proxy::run_http_proxy(
|
||||
self.state.clone(),
|
||||
self.http_addr,
|
||||
self.policy_decider.clone(),
|
||||
));
|
||||
let admin_task = tokio::spawn(admin::run_admin_api(self.state.clone(), self.admin_addr));
|
||||
|
||||
Ok(NetworkProxyHandle {
|
||||
http_task: Some(http_task),
|
||||
admin_task: Some(admin_task),
|
||||
completed: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NetworkProxyHandle {
|
||||
http_task: Option<JoinHandle<Result<()>>>,
|
||||
admin_task: Option<JoinHandle<Result<()>>>,
|
||||
completed: bool,
|
||||
}
|
||||
|
||||
impl NetworkProxyHandle {
|
||||
fn noop() -> Self {
|
||||
Self {
|
||||
http_task: Some(tokio::spawn(async { Ok(()) })),
|
||||
admin_task: Some(tokio::spawn(async { Ok(()) })),
|
||||
completed: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wait(mut self) -> Result<()> {
|
||||
let http_task = self.http_task.take().context("missing http proxy task")?;
|
||||
let admin_task = self.admin_task.take().context("missing admin proxy task")?;
|
||||
let http_result = http_task.await;
|
||||
let admin_result = admin_task.await;
|
||||
self.completed = true;
|
||||
http_result??;
|
||||
admin_result??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn shutdown(mut self) -> Result<()> {
|
||||
abort_tasks(self.http_task.take(), self.admin_task.take()).await;
|
||||
self.completed = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort_tasks(
|
||||
http_task: Option<JoinHandle<Result<()>>>,
|
||||
admin_task: Option<JoinHandle<Result<()>>>,
|
||||
) {
|
||||
if let Some(http_task) = http_task {
|
||||
http_task.abort();
|
||||
let _ = http_task.await;
|
||||
}
|
||||
if let Some(admin_task) = admin_task {
|
||||
admin_task.abort();
|
||||
let _ = admin_task.await;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for NetworkProxyHandle {
|
||||
fn drop(&mut self) {
|
||||
if self.completed {
|
||||
return;
|
||||
}
|
||||
let http_task = self.http_task.take();
|
||||
let admin_task = self.admin_task.take();
|
||||
tokio::spawn(async move {
|
||||
abort_tasks(http_task, admin_task).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
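
A minimal sketch of driving the handle above from a host process, assuming tokio's signal feature is enabled; the Ctrl-C trigger is an illustrative choice, not something this PR wires up:

// Hypothetical shutdown sketch (not part of this diff): run the proxy until
// Ctrl-C, then abort both listener tasks via `NetworkProxyHandle::shutdown`.
use anyhow::Result;
use codex_network_proxy::NetworkProxy;

async fn run_until_ctrl_c() -> Result<()> {
    let handle = NetworkProxy::builder().build().await?.run().await?;
    // Requires the tokio "signal" feature (assumed here).
    tokio::signal::ctrl_c().await?;
    handle.shutdown().await
}
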
6
codex-rs/network-proxy/src/reasons.rs
Normal file
@@ -0,0 +1,6 @@
pub(crate) const REASON_DENIED: &str = "denied";
pub(crate) const REASON_METHOD_NOT_ALLOWED: &str = "method_not_allowed";
pub(crate) const REASON_NOT_ALLOWED: &str = "not_allowed";
pub(crate) const REASON_NOT_ALLOWED_LOCAL: &str = "not_allowed_local";
pub(crate) const REASON_POLICY_DENIED: &str = "policy_denied";
pub(crate) const REASON_PROXY_DISABLED: &str = "proxy_disabled";
67
codex-rs/network-proxy/src/responses.rs
Normal file
@@ -0,0 +1,67 @@
use crate::reasons::REASON_DENIED;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
use rama_http::Body;
use rama_http::Response;
use rama_http::StatusCode;
use serde::Serialize;
use tracing::error;

pub fn text_response(status: StatusCode, body: &str) -> Response {
    Response::builder()
        .status(status)
        .header("content-type", "text/plain")
        .body(Body::from(body.to_string()))
        .unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}

pub fn json_response<T: Serialize>(value: &T) -> Response {
    let body = match serde_json::to_string(value) {
        Ok(body) => body,
        Err(err) => {
            error!("failed to serialize JSON response: {err}");
            "{}".to_string()
        }
    };
    Response::builder()
        .status(StatusCode::OK)
        .header("content-type", "application/json")
        .body(Body::from(body))
        .unwrap_or_else(|err| {
            error!("failed to build JSON response: {err}");
            Response::new(Body::from("{}"))
        })
}

pub fn blocked_header_value(reason: &str) -> &'static str {
    match reason {
        REASON_NOT_ALLOWED | REASON_NOT_ALLOWED_LOCAL => "blocked-by-allowlist",
        REASON_DENIED => "blocked-by-denylist",
        REASON_METHOD_NOT_ALLOWED => "blocked-by-method-policy",
        _ => "blocked-by-policy",
    }
}

pub fn blocked_message(reason: &str) -> &'static str {
    match reason {
        REASON_NOT_ALLOWED => "Codex blocked this request: domain not in allowlist.",
        REASON_NOT_ALLOWED_LOCAL => {
            "Codex blocked this request: local/private addresses not allowed."
        }
        REASON_DENIED => "Codex blocked this request: domain denied by policy.",
        REASON_METHOD_NOT_ALLOWED => {
            "Codex blocked this request: method not allowed in limited mode."
        }
        _ => "Codex blocked this request by network policy.",
    }
}

pub fn blocked_text_response(reason: &str) -> Response {
    Response::builder()
        .status(StatusCode::FORBIDDEN)
        .header("content-type", "text/plain")
        .header("x-proxy-error", blocked_header_value(reason))
        .body(Body::from(blocked_message(reason)))
        .unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
984
codex-rs/network-proxy/src/runtime.rs
Normal file
@@ -0,0 +1,984 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::config::NetworkProxyConfig;
|
||||
use crate::policy::Host;
|
||||
use crate::policy::is_loopback_host;
|
||||
use crate::policy::is_non_public_ip;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::reasons::REASON_DENIED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED;
|
||||
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
|
||||
use crate::state::NetworkProxyConstraints;
|
||||
use crate::state::build_config_state;
|
||||
use crate::state::validate_policy_against_constraints;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use globset::GlobSet;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use time::OffsetDateTime;
|
||||
use tokio::net::lookup_host;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::timeout;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
const MAX_BLOCKED_EVENTS: usize = 200;
|
||||
const DNS_LOOKUP_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum HostBlockReason {
|
||||
Denied,
|
||||
NotAllowed,
|
||||
NotAllowedLocal,
|
||||
}
|
||||
|
||||
impl HostBlockReason {
|
||||
pub const fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Denied => REASON_DENIED,
|
||||
Self::NotAllowed => REASON_NOT_ALLOWED,
|
||||
Self::NotAllowedLocal => REASON_NOT_ALLOWED_LOCAL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HostBlockReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum HostBlockDecision {
|
||||
Allowed,
|
||||
Blocked(HostBlockReason),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct BlockedRequest {
|
||||
pub host: String,
|
||||
pub reason: String,
|
||||
pub client: Option<String>,
|
||||
pub method: Option<String>,
|
||||
pub mode: Option<NetworkMode>,
|
||||
pub protocol: String,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
|
||||
impl BlockedRequest {
|
||||
pub fn new(
|
||||
host: String,
|
||||
reason: String,
|
||||
client: Option<String>,
|
||||
method: Option<String>,
|
||||
mode: Option<NetworkMode>,
|
||||
protocol: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
host,
|
||||
reason,
|
||||
client,
|
||||
method,
|
||||
mode,
|
||||
protocol,
|
||||
timestamp: unix_timestamp(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConfigState {
|
||||
pub(crate) config: NetworkProxyConfig,
|
||||
pub(crate) allow_set: GlobSet,
|
||||
pub(crate) deny_set: GlobSet,
|
||||
pub(crate) constraints: NetworkProxyConstraints,
|
||||
pub(crate) layer_mtimes: Vec<LayerMtime>,
|
||||
pub(crate) cfg_path: PathBuf,
|
||||
pub(crate) blocked: VecDeque<BlockedRequest>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct LayerMtime {
|
||||
pub(crate) path: PathBuf,
|
||||
pub(crate) mtime: Option<SystemTime>,
|
||||
}
|
||||
|
||||
impl LayerMtime {
|
||||
pub(crate) fn new(path: PathBuf) -> Self {
|
||||
let mtime = path.metadata().and_then(|m| m.modified()).ok();
|
||||
Self { path, mtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NetworkProxyState {
|
||||
state: Arc<RwLock<ConfigState>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for NetworkProxyState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Avoid logging internal state (config contents, derived globsets, etc.) which can be noisy
|
||||
// and may contain sensitive paths.
|
||||
f.debug_struct("NetworkProxyState").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkProxyState {
|
||||
pub async fn new() -> Result<Self> {
|
||||
let cfg_state = build_config_state().await?;
|
||||
Ok(Self {
|
||||
state: Arc::new(RwLock::new(cfg_state)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn current_cfg(&self) -> Result<NetworkProxyConfig> {
|
||||
// Callers treat `NetworkProxyState` as a live view of policy. We reload-on-demand so edits to
|
||||
// `config.toml` (including Codex-managed writes) take effect without a restart.
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.clone())
|
||||
}
|
||||
|
||||
pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok((
|
||||
guard.config.network_proxy.policy.allowed_domains.clone(),
|
||||
guard.config.network_proxy.policy.denied_domains.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn enabled(&self) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.enabled)
|
||||
}
|
||||
|
||||
pub async fn force_reload(&self) -> Result<()> {
|
||||
let (previous_cfg, cfg_path) = {
|
||||
let guard = self.state.read().await;
|
||||
(guard.config.clone(), guard.cfg_path.clone())
|
||||
};
|
||||
|
||||
match build_config_state().await {
|
||||
Ok(mut new_state) => {
|
||||
// Policy changes are operationally sensitive; logging diffs makes changes traceable
|
||||
// without needing to dump full config blobs (which can include unrelated settings).
|
||||
log_policy_changes(&previous_cfg, &new_state.config);
|
||||
let mut guard = self.state.write().await;
|
||||
new_state.blocked = guard.blocked.clone();
|
||||
*guard = new_state;
|
||||
let path = guard.cfg_path.display();
|
||||
info!("reloaded config from {path}");
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let path = cfg_path.display();
|
||||
warn!("failed to reload config from {path}: {err}; keeping previous config");
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn host_blocked(&self, host: &str, port: u16) -> Result<HostBlockDecision> {
|
||||
self.reload_if_needed().await?;
|
||||
let host = match Host::parse(host) {
|
||||
Ok(host) => host,
|
||||
Err(_) => return Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowed)),
|
||||
};
|
||||
let (deny_set, allow_set, allow_local_binding, allowed_domains_empty, allowed_domains) = {
|
||||
let guard = self.state.read().await;
|
||||
(
|
||||
guard.deny_set.clone(),
|
||||
guard.allow_set.clone(),
|
||||
guard.config.network_proxy.policy.allow_local_binding,
|
||||
guard.config.network_proxy.policy.allowed_domains.is_empty(),
|
||||
guard.config.network_proxy.policy.allowed_domains.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let host_str = host.as_str();
|
||||
|
||||
// Decision order matters:
|
||||
// 1) explicit deny always wins
|
||||
// 2) local/private networking is opt-in (defense-in-depth)
|
||||
// 3) allowlist is enforced when configured
|
||||
if deny_set.is_match(host_str) {
|
||||
return Ok(HostBlockDecision::Blocked(HostBlockReason::Denied));
|
||||
}
|
||||
|
||||
let is_allowlisted = allow_set.is_match(host_str);
|
||||
if !allow_local_binding {
|
||||
// If the intent is "prevent access to local/internal networks", we must not rely solely
|
||||
// on string checks like `localhost` / `127.0.0.1`. Attackers can use DNS rebinding or
|
||||
// public suffix services that map hostnames onto private IPs.
|
||||
//
|
||||
// We therefore do a best-effort DNS + IP classification check before allowing the
|
||||
// request. Explicit local/loopback literals are allowed only when explicitly
|
||||
// allowlisted; hostnames that resolve to local/private IPs are blocked even if
|
||||
// allowlisted.
|
||||
let local_literal = {
|
||||
let host_no_scope = host_str
|
||||
.split_once('%')
|
||||
.map(|(ip, _)| ip)
|
||||
.unwrap_or(host_str);
|
||||
if is_loopback_host(&host) {
|
||||
true
|
||||
} else if let Ok(ip) = host_no_scope.parse::<IpAddr>() {
|
||||
is_non_public_ip(ip)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if local_literal {
|
||||
if !is_explicit_local_allowlisted(&allowed_domains, &host) {
|
||||
return Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal));
|
||||
}
|
||||
} else if host_resolves_to_non_public_ip(host_str, port).await {
|
||||
return Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal));
|
||||
}
|
||||
}
|
||||
|
||||
if allowed_domains_empty || !is_allowlisted {
|
||||
Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowed))
|
||||
} else {
|
||||
Ok(HostBlockDecision::Allowed)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn record_blocked(&self, entry: BlockedRequest) -> Result<()> {
|
||||
self.reload_if_needed().await?;
|
||||
let mut guard = self.state.write().await;
|
||||
guard.blocked.push_back(entry);
|
||||
while guard.blocked.len() > MAX_BLOCKED_EVENTS {
|
||||
guard.blocked.pop_front();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Drain and return the buffered blocked-request entries in FIFO order.
|
||||
pub async fn drain_blocked(&self) -> Result<Vec<BlockedRequest>> {
|
||||
self.reload_if_needed().await?;
|
||||
let blocked = {
|
||||
let mut guard = self.state.write().await;
|
||||
std::mem::take(&mut guard.blocked)
|
||||
};
|
||||
Ok(blocked.into_iter().collect())
|
||||
}
|
||||
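// Hypothetical polling sketch (not part of this diff): a supervisor can drain
// the buffered blocked events periodically and surface them for interactive
// approval, which is the flow the PR description outlines. The one-second
// interval and the log format are illustrative assumptions only.
async fn poll_blocked_events(state: &NetworkProxyState) -> anyhow::Result<()> {
    loop {
        for event in state.drain_blocked().await? {
            tracing::info!(
                "blocked {} request to {} ({})",
                event.protocol,
                event.host,
                event.reason
            );
        }
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    }
}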
|
||||
pub async fn is_unix_socket_allowed(&self, path: &str) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
if !unix_socket_permissions_supported() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// We only support absolute unix socket paths (a relative path would be ambiguous with
|
||||
// respect to the proxy process's CWD and can lead to confusing allowlist behavior).
|
||||
let requested_path = Path::new(path);
|
||||
if !requested_path.is_absolute() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let guard = self.state.read().await;
|
||||
// Normalize the path while keeping the absolute-path requirement explicit.
|
||||
let requested_abs = match AbsolutePathBuf::from_absolute_path(requested_path) {
|
||||
Ok(path) => path,
|
||||
Err(_) => return Ok(false),
|
||||
};
|
||||
let requested_canonical = std::fs::canonicalize(requested_abs.as_path()).ok();
|
||||
for allowed in &guard.config.network_proxy.policy.allow_unix_sockets {
|
||||
if allowed == path {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Best-effort canonicalization to reduce surprises with symlinks.
|
||||
// If canonicalization fails (e.g., socket not created yet), fall back to raw comparison.
|
||||
let Some(requested_canonical) = &requested_canonical else {
|
||||
continue;
|
||||
};
|
||||
if let Ok(allowed_canonical) = std::fs::canonicalize(allowed)
|
||||
&& &allowed_canonical == requested_canonical
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.mode.allows_method(method))
|
||||
}
|
||||
|
||||
pub async fn allow_upstream_proxy(&self) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.allow_upstream_proxy)
|
||||
}
|
||||
|
||||
pub async fn network_mode(&self) -> Result<NetworkMode> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.mode)
|
||||
}
|
||||
|
||||
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
|
||||
loop {
|
||||
self.reload_if_needed().await?;
|
||||
let (candidate, constraints) = {
|
||||
let guard = self.state.read().await;
|
||||
let mut candidate = guard.config.clone();
|
||||
candidate.network_proxy.mode = mode;
|
||||
(candidate, guard.constraints.clone())
|
||||
};
|
||||
|
||||
validate_policy_against_constraints(&candidate, &constraints)
|
||||
.context("network_proxy.mode constrained by managed config")?;
|
||||
|
||||
let mut guard = self.state.write().await;
|
||||
if guard.constraints != constraints {
|
||||
drop(guard);
|
||||
continue;
|
||||
}
|
||||
guard.config.network_proxy.mode = mode;
|
||||
info!("updated network mode to {mode:?}");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
async fn reload_if_needed(&self) -> Result<()> {
|
||||
let needs_reload = {
|
||||
let guard = self.state.read().await;
|
||||
guard.layer_mtimes.iter().any(|layer| {
|
||||
let metadata = std::fs::metadata(&layer.path).ok();
|
||||
match (metadata.and_then(|m| m.modified().ok()), layer.mtime) {
|
||||
(Some(new_mtime), Some(old_mtime)) => new_mtime > old_mtime,
|
||||
(Some(_), None) => true,
|
||||
(None, Some(_)) => true,
|
||||
(None, None) => false,
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
if !needs_reload {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.force_reload().await
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unix_socket_permissions_supported() -> bool {
|
||||
cfg!(target_os = "macos")
|
||||
}
|
||||
|
||||
async fn host_resolves_to_non_public_ip(host: &str, port: u16) -> bool {
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return is_non_public_ip(ip);
|
||||
}
|
||||
|
||||
// If DNS lookup fails, default to "not local/private" rather than blocking. In practice, the
|
||||
// subsequent connect attempt will fail anyway, and blocking on transient resolver issues would
|
||||
// make the proxy fragile. The allowlist/denylist remains the primary control plane.
|
||||
let addrs = match timeout(DNS_LOOKUP_TIMEOUT, lookup_host((host, port))).await {
|
||||
Ok(Ok(addrs)) => addrs,
|
||||
Ok(Err(_)) | Err(_) => return false,
|
||||
};
|
||||
|
||||
for addr in addrs {
|
||||
if is_non_public_ip(addr.ip()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn log_policy_changes(previous: &NetworkProxyConfig, next: &NetworkProxyConfig) {
|
||||
log_domain_list_changes(
|
||||
"allowlist",
|
||||
&previous.network_proxy.policy.allowed_domains,
|
||||
&next.network_proxy.policy.allowed_domains,
|
||||
);
|
||||
log_domain_list_changes(
|
||||
"denylist",
|
||||
&previous.network_proxy.policy.denied_domains,
|
||||
&next.network_proxy.policy.denied_domains,
|
||||
);
|
||||
}
|
||||
|
||||
fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]) {
|
||||
let previous_set: HashSet<String> = previous
|
||||
.iter()
|
||||
.map(|entry| entry.to_ascii_lowercase())
|
||||
.collect();
|
||||
let next_set: HashSet<String> = next
|
||||
.iter()
|
||||
.map(|entry| entry.to_ascii_lowercase())
|
||||
.collect();
|
||||
|
||||
let added = next_set
|
||||
.difference(&previous_set)
|
||||
.cloned()
|
||||
.collect::<HashSet<_>>();
|
||||
let removed = previous_set
|
||||
.difference(&next_set)
|
||||
.cloned()
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let mut seen_next = HashSet::new();
|
||||
for entry in next {
|
||||
let key = entry.to_ascii_lowercase();
|
||||
if seen_next.insert(key.clone()) && added.contains(&key) {
|
||||
info!("config entry added to {list_name}: {entry}");
|
||||
}
|
||||
}
|
||||
|
||||
let mut seen_previous = HashSet::new();
|
||||
for entry in previous {
|
||||
let key = entry.to_ascii_lowercase();
|
||||
if seen_previous.insert(key.clone()) && removed.contains(&key) {
|
||||
info!("config entry removed from {list_name}: {entry}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_explicit_local_allowlisted(allowed_domains: &[String], host: &Host) -> bool {
|
||||
let normalized_host = host.as_str();
|
||||
allowed_domains.iter().any(|pattern| {
|
||||
let pattern = pattern.trim();
|
||||
if pattern == "*" || pattern.starts_with("*.") || pattern.starts_with("**.") {
|
||||
return false;
|
||||
}
|
||||
if pattern.contains('*') || pattern.contains('?') {
|
||||
return false;
|
||||
}
|
||||
normalize_host(pattern) == normalized_host
|
||||
})
|
||||
}
|
||||
|
||||
fn unix_timestamp() -> i64 {
|
||||
OffsetDateTime::now_utc().unix_timestamp()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn network_proxy_state_for_policy(
|
||||
policy: crate::config::NetworkPolicy,
|
||||
) -> NetworkProxyState {
|
||||
let config = NetworkProxyConfig {
|
||||
network_proxy: crate::config::NetworkProxySettings {
|
||||
enabled: true,
|
||||
mode: NetworkMode::Full,
|
||||
policy,
|
||||
..crate::config::NetworkProxySettings::default()
|
||||
},
|
||||
};
|
||||
|
||||
let allow_set =
|
||||
crate::policy::compile_globset(&config.network_proxy.policy.allowed_domains).unwrap();
|
||||
let deny_set =
|
||||
crate::policy::compile_globset(&config.network_proxy.policy.denied_domains).unwrap();
|
||||
|
||||
let state = ConfigState {
|
||||
config,
|
||||
allow_set,
|
||||
deny_set,
|
||||
constraints: NetworkProxyConstraints::default(),
|
||||
layer_mtimes: Vec::new(),
|
||||
cfg_path: PathBuf::from("/nonexistent/config.toml"),
|
||||
blocked: VecDeque::new(),
|
||||
};
|
||||
|
||||
NetworkProxyState {
|
||||
state: Arc::new(RwLock::new(state)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::config::NetworkPolicy;
|
||||
use crate::config::NetworkProxyConfig;
|
||||
use crate::config::NetworkProxySettings;
|
||||
use crate::policy::compile_globset;
|
||||
use crate::state::NetworkProxyConstraints;
|
||||
use crate::state::validate_policy_against_constraints;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_denied_wins_over_allowed() {
|
||||
let state = network_proxy_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
denied_domains: vec!["example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("example.com", 80).await.unwrap(),
|
||||
HostBlockDecision::Blocked(HostBlockReason::Denied)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_requires_allowlist_match() {
|
||||
let state = network_proxy_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("example.com", 80).await.unwrap(),
|
||||
HostBlockDecision::Allowed
|
||||
);
|
||||
assert_eq!(
|
||||
// Use a public IP literal to avoid relying on ambient DNS behavior (some networks
|
||||
// resolve unknown hostnames to private IPs, which would trigger `not_allowed_local`).
|
||||
state.host_blocked("8.8.8.8", 80).await.unwrap(),
|
||||
HostBlockDecision::Blocked(HostBlockReason::NotAllowed)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_subdomain_wildcards_exclude_apex() {
|
||||
let state = network_proxy_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["*.openai.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("api.openai.com", 80).await.unwrap(),
|
||||
HostBlockDecision::Allowed
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("openai.com", 80).await.unwrap(),
|
||||
HostBlockDecision::Blocked(HostBlockReason::NotAllowed)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
async fn host_blocked_rejects_loopback_when_local_binding_disabled() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["example.com".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("127.0.0.1", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
    assert_eq!(
        state.host_blocked("localhost", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
}

#[tokio::test]
async fn host_blocked_rejects_loopback_when_allowlist_is_wildcard() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["*".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("127.0.0.1", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
}

#[tokio::test]
async fn host_blocked_rejects_private_ip_literal_when_allowlist_is_wildcard() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["*".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("10.0.0.1", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
}

#[tokio::test]
async fn host_blocked_allows_loopback_when_explicitly_allowlisted_and_local_binding_disabled() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["localhost".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("localhost", 80).await.unwrap(),
        HostBlockDecision::Allowed
    );
}

#[tokio::test]
async fn host_blocked_allows_private_ip_literal_when_explicitly_allowlisted() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["10.0.0.1".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("10.0.0.1", 80).await.unwrap(),
        HostBlockDecision::Allowed
    );
}

#[tokio::test]
async fn host_blocked_rejects_scoped_ipv6_literal_when_not_allowlisted() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["example.com".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
}

#[tokio::test]
async fn host_blocked_allows_scoped_ipv6_literal_when_explicitly_allowlisted() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["fe80::1%lo0".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
        HostBlockDecision::Allowed
    );
}

#[tokio::test]
async fn host_blocked_rejects_private_ip_literals_when_local_binding_disabled() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["example.com".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("10.0.0.1", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
}

#[tokio::test]
async fn host_blocked_rejects_loopback_when_allowlist_empty() {
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec![],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    });

    assert_eq!(
        state.host_blocked("127.0.0.1", 80).await.unwrap(),
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    );
}

#[test]
fn validate_policy_against_constraints_disallows_widening_allowed_domains() {
    let constraints = NetworkProxyConstraints {
        allowed_domains: Some(vec!["example.com".to_string()]),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            policy: NetworkPolicy {
                allowed_domains: vec!["example.com".to_string(), "evil.com".to_string()],
                ..NetworkPolicy::default()
            },
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_disallows_widening_mode() {
    let constraints = NetworkProxyConstraints {
        mode: Some(NetworkMode::Limited),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            mode: NetworkMode::Full,
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_allows_narrowing_wildcard_allowlist() {
    let constraints = NetworkProxyConstraints {
        allowed_domains: Some(vec!["*.example.com".to_string()]),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            policy: NetworkPolicy {
                allowed_domains: vec!["api.example.com".to_string()],
                ..NetworkPolicy::default()
            },
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
}

#[test]
fn validate_policy_against_constraints_rejects_widening_wildcard_allowlist() {
    let constraints = NetworkProxyConstraints {
        allowed_domains: Some(vec!["*.example.com".to_string()]),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            policy: NetworkPolicy {
                allowed_domains: vec!["**.example.com".to_string()],
                ..NetworkPolicy::default()
            },
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_requires_managed_denied_domains_entries() {
    let constraints = NetworkProxyConstraints {
        denied_domains: Some(vec!["evil.com".to_string()]),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            policy: NetworkPolicy {
                denied_domains: vec![],
                ..NetworkPolicy::default()
            },
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_disallows_enabling_when_managed_disabled() {
    let constraints = NetworkProxyConstraints {
        enabled: Some(false),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_disallows_allow_local_binding_when_managed_disabled() {
    let constraints = NetworkProxyConstraints {
        allow_local_binding: Some(false),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            policy: NetworkPolicy {
                allow_local_binding: true,
                ..NetworkPolicy::default()
            },
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_disallows_non_loopback_admin_without_managed_opt_in() {
    let constraints = NetworkProxyConstraints {
        dangerously_allow_non_loopback_admin: Some(false),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            dangerously_allow_non_loopback_admin: true,
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}

#[test]
fn validate_policy_against_constraints_allows_non_loopback_admin_with_managed_opt_in() {
    let constraints = NetworkProxyConstraints {
        dangerously_allow_non_loopback_admin: Some(true),
        ..NetworkProxyConstraints::default()
    };

    let config = NetworkProxyConfig {
        network_proxy: NetworkProxySettings {
            enabled: true,
            dangerously_allow_non_loopback_admin: true,
            ..NetworkProxySettings::default()
        },
    };

    assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
}

#[test]
fn compile_globset_is_case_insensitive() {
    let patterns = vec!["ExAmPle.CoM".to_string()];
    let set = compile_globset(&patterns).unwrap();
    assert!(set.is_match("example.com"));
    assert!(set.is_match("EXAMPLE.COM"));
}

#[test]
fn compile_globset_excludes_apex_for_subdomain_patterns() {
    let patterns = vec!["*.openai.com".to_string()];
    let set = compile_globset(&patterns).unwrap();
    assert!(set.is_match("api.openai.com"));
    assert!(!set.is_match("openai.com"));
    assert!(!set.is_match("evilopenai.com"));
}

#[test]
fn compile_globset_includes_apex_for_double_wildcard_patterns() {
    let patterns = vec!["**.openai.com".to_string()];
    let set = compile_globset(&patterns).unwrap();
    assert!(set.is_match("openai.com"));
    assert!(set.is_match("api.openai.com"));
    assert!(!set.is_match("evilopenai.com"));
}

#[test]
fn compile_globset_matches_all_with_star() {
    let patterns = vec!["*".to_string()];
    let set = compile_globset(&patterns).unwrap();
    assert!(set.is_match("openai.com"));
    assert!(set.is_match("api.openai.com"));
}

#[test]
fn compile_globset_dedupes_patterns_without_changing_behavior() {
    let patterns = vec!["example.com".to_string(), "example.com".to_string()];
    let set = compile_globset(&patterns).unwrap();
    assert!(set.is_match("example.com"));
    assert!(set.is_match("EXAMPLE.COM"));
    assert!(!set.is_match("not-example.com"));
}

#[test]
fn compile_globset_rejects_invalid_patterns() {
    let patterns = vec!["[".to_string()];
    assert!(compile_globset(&patterns).is_err());
}

#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_is_respected_on_macos() {
    let socket_path = "/tmp/example.sock".to_string();
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["example.com".to_string()],
        allow_unix_sockets: vec![socket_path.clone()],
        ..NetworkPolicy::default()
    });

    assert!(state.is_unix_socket_allowed(&socket_path).await.unwrap());
    assert!(
        !state
            .is_unix_socket_allowed("/tmp/not-allowed.sock")
            .await
            .unwrap()
    );
}

#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_resolves_symlinks() {
    use std::os::unix::fs::symlink;
    use tempfile::tempdir;

    let temp_dir = tempdir().unwrap();
    let dir = temp_dir.path();

    let real = dir.join("real.sock");
    let link = dir.join("link.sock");

    // The allowlist mechanism is path-based; for test purposes we don't need an actual unix
    // domain socket. Any filesystem entry works for canonicalization.
    std::fs::write(&real, b"not a socket").unwrap();
    symlink(&real, &link).unwrap();

    let real_s = real.to_str().unwrap().to_string();
    let link_s = link.to_str().unwrap().to_string();

    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["example.com".to_string()],
        allow_unix_sockets: vec![real_s],
        ..NetworkPolicy::default()
    });

    assert!(state.is_unix_socket_allowed(&link_s).await.unwrap());
}

#[cfg(not(target_os = "macos"))]
#[tokio::test]
async fn unix_socket_allowlist_is_rejected_on_non_macos() {
    let socket_path = "/tmp/example.sock".to_string();
    let state = network_proxy_state_for_policy(NetworkPolicy {
        allowed_domains: vec!["example.com".to_string()],
        allow_unix_sockets: vec![socket_path.clone()],
        ..NetworkPolicy::default()
    });

    assert!(!state.is_unix_socket_allowed(&socket_path).await.unwrap());
}
}
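Taken together, these tests pin down the policy semantics: a `*` or `*.domain` allowlist never implicitly admits loopback or private addresses, the apex domain is matched only by `**.` patterns, deny entries take precedence over the allowlist, and managed deny entries must survive into the effective policy. A minimal illustrative sketch (not part of the diff) of a policy that leans on those rules, built from the same `NetworkPolicy` type the tests construct:

    // Illustration only: any subdomain of example.com, apex excluded,
    // with a deny entry that always wins and local destinations blocked.
    let policy = NetworkPolicy {
        allowed_domains: vec!["*.example.com".to_string()],
        denied_domains: vec!["evil.com".to_string()],
        allow_local_binding: false,
        ..NetworkPolicy::default()
    };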
419
codex-rs/network-proxy/src/state.rs
Normal file
@@ -0,0 +1,419 @@
use crate::config::NetworkMode;
use crate::config::NetworkProxyConfig;
use crate::policy::DomainPattern;
use crate::policy::compile_globset;
use crate::runtime::ConfigState;
use crate::runtime::LayerMtime;
use anyhow::Context;
use anyhow::Result;
use codex_app_server_protocol::ConfigLayerSource;
use codex_core::config::CONFIG_TOML_FILE;
use codex_core::config::Constrained;
use codex_core::config::ConstraintError;
use codex_core::config::find_codex_home;
use codex_core::config_loader::ConfigLayerStack;
use codex_core::config_loader::ConfigLayerStackOrdering;
use codex_core::config_loader::LoaderOverrides;
use codex_core::config_loader::RequirementSource;
use codex_core::config_loader::load_config_layers_state;
use serde::Deserialize;
use std::collections::HashSet;

pub use crate::runtime::BlockedRequest;
pub use crate::runtime::NetworkProxyState;
#[cfg(test)]
pub(crate) use crate::runtime::network_proxy_state_for_policy;

pub(crate) async fn build_config_state() -> Result<ConfigState> {
    // Load config through `codex-core` so we inherit the same layer ordering and semantics as the
    // rest of Codex (system/managed layers, user layers, session flags, etc.).
    let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
    let cli_overrides = Vec::new();
    let overrides = LoaderOverrides::default();
    let config_layer_stack = load_config_layers_state(&codex_home, None, &cli_overrides, overrides)
        .await
        .context("failed to load Codex config")?;

    let cfg_path = codex_home.join(CONFIG_TOML_FILE);

    // Deserialize from the merged effective config, rather than parsing config.toml ourselves.
    // This avoids a second parser/merger implementation (and the drift that comes with it).
    let merged_toml = config_layer_stack.effective_config();
    let config: NetworkProxyConfig = merged_toml
        .try_into()
        .context("failed to deserialize network proxy config")?;

    // Security boundary: user-controlled layers must not be able to widen restrictions set by
    // trusted/managed layers (e.g., MDM). Enforce this before building runtime state.
    let constraints = enforce_trusted_constraints(&config_layer_stack, &config)?;

    let layer_mtimes = collect_layer_mtimes(&config_layer_stack);
    let deny_set = compile_globset(&config.network_proxy.policy.denied_domains)?;
    let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains)?;
    Ok(ConfigState {
        config,
        allow_set,
        deny_set,
        constraints,
        layer_mtimes,
        cfg_path,
        blocked: std::collections::VecDeque::new(),
    })
}

fn collect_layer_mtimes(stack: &ConfigLayerStack) -> Vec<LayerMtime> {
    stack
        .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, false)
        .iter()
        .filter_map(|layer| {
            let path = match &layer.name {
                ConfigLayerSource::System { file } => Some(file.as_path().to_path_buf()),
                ConfigLayerSource::User { file } => Some(file.as_path().to_path_buf()),
                ConfigLayerSource::Project { dot_codex_folder } => dot_codex_folder
                    .join(CONFIG_TOML_FILE)
                    .ok()
                    .map(|p| p.as_path().to_path_buf()),
                ConfigLayerSource::LegacyManagedConfigTomlFromFile { file } => {
                    Some(file.as_path().to_path_buf())
                }
                _ => None,
            };
            path.map(LayerMtime::new)
        })
        .collect()
}

#[derive(Debug, Default, Deserialize)]
struct PartialConfig {
    #[serde(default)]
    network_proxy: PartialNetworkProxyConfig,
}

#[derive(Debug, Default, Deserialize)]
struct PartialNetworkProxyConfig {
    enabled: Option<bool>,
    mode: Option<NetworkMode>,
    allow_upstream_proxy: Option<bool>,
    dangerously_allow_non_loopback_proxy: Option<bool>,
    dangerously_allow_non_loopback_admin: Option<bool>,
    #[serde(default)]
    policy: PartialNetworkPolicy,
}

#[derive(Debug, Default, Deserialize)]
struct PartialNetworkPolicy {
    #[serde(default)]
    allowed_domains: Option<Vec<String>>,
    #[serde(default)]
    denied_domains: Option<Vec<String>>,
    #[serde(default)]
    allow_unix_sockets: Option<Vec<String>>,
    #[serde(default)]
    allow_local_binding: Option<bool>,
}

#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(crate) struct NetworkProxyConstraints {
    pub(crate) enabled: Option<bool>,
    pub(crate) mode: Option<NetworkMode>,
    pub(crate) allow_upstream_proxy: Option<bool>,
    pub(crate) dangerously_allow_non_loopback_proxy: Option<bool>,
    pub(crate) dangerously_allow_non_loopback_admin: Option<bool>,
    pub(crate) allowed_domains: Option<Vec<String>>,
    pub(crate) denied_domains: Option<Vec<String>>,
    pub(crate) allow_unix_sockets: Option<Vec<String>>,
    pub(crate) allow_local_binding: Option<bool>,
}

fn enforce_trusted_constraints(
    layers: &codex_core::config_loader::ConfigLayerStack,
    config: &NetworkProxyConfig,
) -> Result<NetworkProxyConstraints> {
    let constraints = network_proxy_constraints_from_trusted_layers(layers)?;
    validate_policy_against_constraints(config, &constraints)
        .context("network proxy constraints")?;
    Ok(constraints)
}

fn network_proxy_constraints_from_trusted_layers(
    layers: &codex_core::config_loader::ConfigLayerStack,
) -> Result<NetworkProxyConstraints> {
    let mut constraints = NetworkProxyConstraints::default();
    for layer in layers.get_layers(
        codex_core::config_loader::ConfigLayerStackOrdering::LowestPrecedenceFirst,
        false,
    ) {
        // Only trusted layers contribute constraints. User-controlled layers can narrow policy but
        // must never widen beyond what managed config allows.
        if is_user_controlled_layer(&layer.name) {
            continue;
        }

        let partial: PartialConfig = layer
            .config
            .clone()
            .try_into()
            .context("failed to deserialize trusted config layer")?;

        if let Some(enabled) = partial.network_proxy.enabled {
            constraints.enabled = Some(enabled);
        }
        if let Some(mode) = partial.network_proxy.mode {
            constraints.mode = Some(mode);
        }
        if let Some(allow_upstream_proxy) = partial.network_proxy.allow_upstream_proxy {
            constraints.allow_upstream_proxy = Some(allow_upstream_proxy);
        }
        if let Some(dangerously_allow_non_loopback_proxy) =
            partial.network_proxy.dangerously_allow_non_loopback_proxy
        {
            constraints.dangerously_allow_non_loopback_proxy =
                Some(dangerously_allow_non_loopback_proxy);
        }
        if let Some(dangerously_allow_non_loopback_admin) =
            partial.network_proxy.dangerously_allow_non_loopback_admin
        {
            constraints.dangerously_allow_non_loopback_admin =
                Some(dangerously_allow_non_loopback_admin);
        }

        if let Some(allowed_domains) = partial.network_proxy.policy.allowed_domains {
            constraints.allowed_domains = Some(allowed_domains);
        }
        if let Some(denied_domains) = partial.network_proxy.policy.denied_domains {
            constraints.denied_domains = Some(denied_domains);
        }
        if let Some(allow_unix_sockets) = partial.network_proxy.policy.allow_unix_sockets {
            constraints.allow_unix_sockets = Some(allow_unix_sockets);
        }
        if let Some(allow_local_binding) = partial.network_proxy.policy.allow_local_binding {
            constraints.allow_local_binding = Some(allow_local_binding);
        }
    }
    Ok(constraints)
}

fn is_user_controlled_layer(layer: &ConfigLayerSource) -> bool {
    matches!(
        layer,
        ConfigLayerSource::User { .. }
            | ConfigLayerSource::Project { .. }
            | ConfigLayerSource::SessionFlags
    )
}

pub(crate) fn validate_policy_against_constraints(
    config: &NetworkProxyConfig,
    constraints: &NetworkProxyConstraints,
) -> std::result::Result<(), ConstraintError> {
    fn invalid_value(
        field_name: &'static str,
        candidate: impl Into<String>,
        allowed: impl Into<String>,
    ) -> ConstraintError {
        ConstraintError::InvalidValue {
            field_name,
            candidate: candidate.into(),
            allowed: allowed.into(),
            requirement_source: RequirementSource::Unknown,
        }
    }

    let enabled = config.network_proxy.enabled;
    if let Some(max_enabled) = constraints.enabled {
        let _ = Constrained::new(enabled, move |candidate| {
            if *candidate && !max_enabled {
                Err(invalid_value(
                    "network_proxy.enabled",
                    "true",
                    "false (disabled by managed config)",
                ))
            } else {
                Ok(())
            }
        })?;
    }

    if let Some(max_mode) = constraints.mode {
        let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
            if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
                Err(invalid_value(
                    "network_proxy.mode",
                    format!("{candidate:?}"),
                    format!("{max_mode:?} or more restrictive"),
                ))
            } else {
                Ok(())
            }
        })?;
    }

    let allow_upstream_proxy = constraints.allow_upstream_proxy;
    let _ = Constrained::new(
        config.network_proxy.allow_upstream_proxy,
        move |candidate| match allow_upstream_proxy {
            Some(true) | None => Ok(()),
            Some(false) => {
                if *candidate {
                    Err(invalid_value(
                        "network_proxy.allow_upstream_proxy",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            }
        },
    )?;

    let allow_non_loopback_admin = constraints.dangerously_allow_non_loopback_admin;
    let _ = Constrained::new(
        config.network_proxy.dangerously_allow_non_loopback_admin,
        move |candidate| match allow_non_loopback_admin {
            Some(true) | None => Ok(()),
            Some(false) => {
                if *candidate {
                    Err(invalid_value(
                        "network_proxy.dangerously_allow_non_loopback_admin",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            }
        },
    )?;

    let allow_non_loopback_proxy = constraints.dangerously_allow_non_loopback_proxy;
    let _ = Constrained::new(
        config.network_proxy.dangerously_allow_non_loopback_proxy,
        move |candidate| match allow_non_loopback_proxy {
            Some(true) | None => Ok(()),
            Some(false) => {
                if *candidate {
                    Err(invalid_value(
                        "network_proxy.dangerously_allow_non_loopback_proxy",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            }
        },
    )?;

    if let Some(allow_local_binding) = constraints.allow_local_binding {
        let _ = Constrained::new(
            config.network_proxy.policy.allow_local_binding,
            move |candidate| {
                if *candidate && !allow_local_binding {
                    Err(invalid_value(
                        "network_proxy.policy.allow_local_binding",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            },
        )?;
    }

    if let Some(allowed_domains) = &constraints.allowed_domains {
        let managed_patterns: Vec<DomainPattern> = allowed_domains
            .iter()
            .map(|entry| DomainPattern::parse_for_constraints(entry))
            .collect();
        let _ = Constrained::new(
            config.network_proxy.policy.allowed_domains.clone(),
            move |candidate| {
                let mut invalid = Vec::new();
                for entry in candidate {
                    let candidate_pattern = DomainPattern::parse_for_constraints(entry);
                    if !managed_patterns
                        .iter()
                        .any(|managed| managed.allows(&candidate_pattern))
                    {
                        invalid.push(entry.clone());
                    }
                }
                if invalid.is_empty() {
                    Ok(())
                } else {
                    Err(invalid_value(
                        "network_proxy.policy.allowed_domains",
                        format!("{invalid:?}"),
                        "subset of managed allowed_domains",
                    ))
                }
            },
        )?;
    }

    if let Some(denied_domains) = &constraints.denied_domains {
        let required_set: HashSet<String> = denied_domains
            .iter()
            .map(|s| s.to_ascii_lowercase())
            .collect();
        let _ = Constrained::new(
            config.network_proxy.policy.denied_domains.clone(),
            move |candidate| {
                let candidate_set: HashSet<String> =
                    candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
                let missing: Vec<String> = required_set
                    .iter()
                    .filter(|entry| !candidate_set.contains(*entry))
                    .cloned()
                    .collect();
                if missing.is_empty() {
                    Ok(())
                } else {
                    Err(invalid_value(
                        "network_proxy.policy.denied_domains",
                        "missing managed denied_domains entries",
                        format!("{missing:?}"),
                    ))
                }
            },
        )?;
    }

    if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
        let allowed_set: HashSet<String> = allow_unix_sockets
            .iter()
            .map(|s| s.to_ascii_lowercase())
            .collect();
        let _ = Constrained::new(
            config.network_proxy.policy.allow_unix_sockets.clone(),
            move |candidate| {
                let mut invalid = Vec::new();
                for entry in candidate {
                    if !allowed_set.contains(&entry.to_ascii_lowercase()) {
                        invalid.push(entry.clone());
                    }
                }
                if invalid.is_empty() {
                    Ok(())
                } else {
                    Err(invalid_value(
                        "network_proxy.policy.allow_unix_sockets",
                        format!("{invalid:?}"),
                        "subset of managed allow_unix_sockets",
                    ))
                }
            },
        )?;
    }

    Ok(())
}

fn network_mode_rank(mode: NetworkMode) -> u8 {
    match mode {
        NetworkMode::Limited => 0,
        NetworkMode::Full => 1,
    }
}
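The mode constraint above relies on `network_mode_rank` giving a total order with `Limited` below `Full`: a managed `Limited` cap rejects a user `Full` override while still accepting `Limited`. A one-line illustrative check (not part of the diff), mirroring the rank values defined above:

    // Limited (0) ranks below Full (1), so candidate > cap detects widening.
    assert!(network_mode_rank(NetworkMode::Full) > network_mode_rank(NetworkMode::Limited));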
188
codex-rs/network-proxy/src/upstream.rs
Normal file
@@ -0,0 +1,188 @@
use rama_core::Layer;
use rama_core::Service;
use rama_core::error::BoxError;
use rama_core::error::ErrorContext as _;
use rama_core::error::OpaqueError;
use rama_core::extensions::ExtensionsMut;
use rama_core::extensions::ExtensionsRef;
use rama_core::service::BoxService;
use rama_http::Body;
use rama_http::Request;
use rama_http::Response;
use rama_http::layer::version_adapter::RequestVersionAdapter;
use rama_http_backend::client::HttpClientService;
use rama_http_backend::client::HttpConnector;
use rama_http_backend::client::proxy::layer::HttpProxyConnectorLayer;
use rama_net::address::ProxyAddress;
use rama_net::client::EstablishedClientConnection;
use rama_net::http::RequestContext;
use rama_tcp::client::service::TcpConnector;
use rama_tls_boring::client::TlsConnectorDataBuilder;
use rama_tls_boring::client::TlsConnectorLayer;
use tracing::warn;

#[cfg(target_os = "macos")]
use rama_unix::client::UnixConnector;

#[derive(Clone, Default)]
struct ProxyConfig {
    http: Option<ProxyAddress>,
    https: Option<ProxyAddress>,
    all: Option<ProxyAddress>,
}

impl ProxyConfig {
    fn from_env() -> Self {
        let http = read_proxy_env(&["HTTP_PROXY", "http_proxy"]);
        let https = read_proxy_env(&["HTTPS_PROXY", "https_proxy"]);
        let all = read_proxy_env(&["ALL_PROXY", "all_proxy"]);
        Self { http, https, all }
    }

    fn proxy_for_request(&self, req: &Request) -> Option<ProxyAddress> {
        let is_secure = RequestContext::try_from(req)
            .map(|ctx| ctx.protocol.is_secure())
            .unwrap_or(false);
        self.proxy_for_protocol(is_secure)
    }

    fn proxy_for_protocol(&self, is_secure: bool) -> Option<ProxyAddress> {
        if is_secure {
            self.https
                .clone()
                .or_else(|| self.http.clone())
                .or_else(|| self.all.clone())
        } else {
            self.http.clone().or_else(|| self.all.clone())
        }
    }
}

fn read_proxy_env(keys: &[&str]) -> Option<ProxyAddress> {
    for key in keys {
        let Ok(value) = std::env::var(key) else {
            continue;
        };
        let value = value.trim();
        if value.is_empty() {
            continue;
        }
        match ProxyAddress::try_from(value) {
            Ok(proxy) => {
                if proxy
                    .protocol
                    .as_ref()
                    .map(rama_net::Protocol::is_http)
                    .unwrap_or(true)
                {
                    return Some(proxy);
                }
                warn!("ignoring {key}: non-http proxy protocol");
            }
            Err(err) => {
                warn!("ignoring {key}: invalid proxy address ({err})");
            }
        }
    }
    None
}

pub(crate) fn proxy_for_connect() -> Option<ProxyAddress> {
    ProxyConfig::from_env().proxy_for_protocol(true)
}

#[derive(Clone)]
pub(crate) struct UpstreamClient {
    connector: BoxService<
        Request<Body>,
        EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
        BoxError,
    >,
    proxy_config: ProxyConfig,
}

impl UpstreamClient {
    pub(crate) fn direct() -> Self {
        Self::new(ProxyConfig::default())
    }

    pub(crate) fn from_env_proxy() -> Self {
        Self::new(ProxyConfig::from_env())
    }

    #[cfg(target_os = "macos")]
    pub(crate) fn unix_socket(path: &str) -> Self {
        let connector = build_unix_connector(path);
        Self {
            connector,
            proxy_config: ProxyConfig::default(),
        }
    }

    fn new(proxy_config: ProxyConfig) -> Self {
        let connector = build_http_connector();
        Self {
            connector,
            proxy_config,
        }
    }
}

impl Service<Request<Body>> for UpstreamClient {
    type Output = Response;
    type Error = OpaqueError;

    async fn serve(&self, mut req: Request<Body>) -> Result<Self::Output, Self::Error> {
        if let Some(proxy) = self.proxy_config.proxy_for_request(&req) {
            req.extensions_mut().insert(proxy);
        }

        let uri = req.uri().clone();
        let EstablishedClientConnection {
            input: mut req,
            conn: http_connection,
        } = self
            .connector
            .serve(req)
            .await
            .map_err(OpaqueError::from_boxed)?;

        req.extensions_mut()
            .extend(http_connection.extensions().clone());

        http_connection
            .serve(req)
            .await
            .map_err(OpaqueError::from_boxed)
            .with_context(|| format!("http request failure for uri: {uri}"))
    }
}

fn build_http_connector() -> BoxService<
    Request<Body>,
    EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
    BoxError,
> {
    let transport = TcpConnector::default();
    let proxy = HttpProxyConnectorLayer::optional().into_layer(transport);
    let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
    let tls = TlsConnectorLayer::auto()
        .with_connector_data(tls_config)
        .into_layer(proxy);
    let tls = RequestVersionAdapter::new(tls);
    let connector = HttpConnector::new(tls);
    connector.boxed()
}

#[cfg(target_os = "macos")]
fn build_unix_connector(
    path: &str,
) -> BoxService<
    Request<Body>,
    EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
    BoxError,
> {
    let transport = UnixConnector::fixed(path);
    let connector = HttpConnector::new(transport);
    connector.boxed()
}
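The env handling above gives HTTPS traffic a fallback chain of HTTPS_PROXY, then HTTP_PROXY, then ALL_PROXY (plain HTTP skips HTTPS_PROXY), with lowercase variants also consulted; values that are empty, unparsable, or point at a non-HTTP proxy are logged and ignored. A standalone sketch of that precedence (illustrative only; `ProxyConfig` itself is crate-private and not part of the diff):

    // Mirrors ProxyConfig::proxy_for_protocol for documentation purposes.
    fn pick_proxy<'a>(
        https: Option<&'a str>,
        http: Option<&'a str>,
        all: Option<&'a str>,
        is_secure: bool,
    ) -> Option<&'a str> {
        if is_secure {
            https.or(http).or(all)
        } else {
            http.or(all)
        }
    }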