Compare commits

...

67 Commits

Author SHA1 Message Date
viyatb-oai
3788afd19d Add MITM support to network proxy 2026-01-22 22:50:45 -08:00
viyatb-oai
5738cf125c network-proxy: add SOCKS5 listener 2026-01-22 16:23:10 -08:00
viyatb-oai
6cb436a4e2 network-proxy: focus PR1 on core http + policy 2026-01-22 14:48:03 -08:00
viyatb-oai
4196294c26 network-proxy: fix clippy test literals 2026-01-22 11:21:46 -08:00
viyatb-oai
d54757bbbd network-proxy: harden policy enforcement 2026-01-22 08:22:36 -08:00
viyatb-oai
fc33c31258 ci: call musl setup script via GITHUB_WORKSPACE 2026-01-21 16:45:23 -08:00
viyatb-oai
4a0c292de3 ci: dedupe musl install steps 2026-01-21 16:38:05 -08:00
viyatb-oai
1dd695246f remove cmake and clang setup from bazel.yml 2026-01-21 16:15:10 -08:00
viyatb-oai
872d0ae3db Adjust runtime formatting for rustfmt 2026-01-21 08:17:34 -08:00
viyatb-oai
5d7f98a5a8 Split network proxy state into runtime and policy modules 2026-01-21 08:04:01 -08:00
viyatb-oai
58562a2a43 Merge branch 'main' into pr/network-proxy-crate 2026-01-21 07:45:55 -08:00
viyatb-oai
4995f09c47 refactor state.rs into manageable modules 2026-01-21 00:39:15 -08:00
viyatb-oai
e3d19064be Harden local binding checks for IPv6 literals 2026-01-20 22:08:51 -08:00
viyatb-oai
e4c003d108 Honor proxy enablement and local binding rules 2026-01-20 21:36:01 -08:00
viyatb-oai
bcdedf5211 update messaging 2026-01-20 13:22:44 -08:00
viyatb-oai
f1cc7fbae8 Clamp proxy binds when unix sockets enabled 2026-01-20 10:43:29 -08:00
viyatb-oai
90c24700ac Fix CONNECT proxy handling and enforce managed network constraints 2026-01-20 08:21:54 -08:00
viyatb-oai
741b661cfa Tighten domain policy matching 2026-01-19 23:23:56 -08:00
viyatb-oai
8637043a0c Revert "Disable aws-lc bindgen in rama-crypto for Bazel"
This reverts commit c656278537.
2026-01-19 23:13:56 -08:00
viyatb-oai
fe1c1c859f Use clang as musl C compiler 2026-01-19 22:43:22 -08:00
viyatb-oai
c8b7c0091a Force pthreads for musl CMake 2026-01-19 22:29:39 -08:00
viyatb-oai
57c971470d Use clang++ for musl C++ headers 2026-01-19 22:15:49 -08:00
viyatb-oai
5d6611170b Install g++ for musl C++ headers 2026-01-19 22:00:41 -08:00
viyatb-oai
3d1e12b49e Fix musl compiler path quoting 2026-01-19 21:48:15 -08:00
viyatb-oai
7f44c725fb Fallback to musl-gcc for musl CXX 2026-01-19 21:36:44 -08:00
viyatb-oai
e8cff7ef77 Install musl g++ for CI 2026-01-19 21:19:33 -08:00
viyatb-oai
d85717dcf8 use individual rama crates and boring-ssl 2026-01-19 20:52:09 -08:00
viyatb-oai
c656278537 Disable aws-lc bindgen in rama-crypto for Bazel 2026-01-18 19:36:20 -08:00
viyatb-oai
8338bebc32 Make upstream proxy opt-in 2026-01-17 21:53:30 -08:00
viyatb-oai
302e6eea7d Revert cargo-bin Cargo.toml to origin/main 2026-01-16 23:24:53 -08:00
viyatb-oai
bd0ff89517 Revert cargo-bin fallback to origin/main 2026-01-16 23:23:18 -08:00
viyatb-oai
cbb5f48ba3 Remove metadata extraction from proxy 2026-01-16 23:17:41 -08:00
viyatb-oai
6c1df8b73e Merge origin/main 2026-01-16 11:30:45 -08:00
viyatb-oai
be94fb6913 CI: install libclang for Bazel 2026-01-15 14:01:27 -08:00
viyatb-oai
8f6413ce5d Bazel: disable cmake for aws-lc-sys 2026-01-15 13:11:41 -08:00
viyatb-oai
a61ab56b77 Merge origin/main into pr/network-proxy-crate 2026-01-14 17:33:14 -08:00
viyatb-oai
258b7ecdbd Revert "Bazel: disable cmake for aws-lc-sys"
This reverts commit 826e40683e.
2026-01-13 19:14:56 -08:00
viyatb-oai
b49b83847d Revert "Bazel: drop aws-lc bindgen in rama-crypto"
This reverts commit 981c7c3261.
2026-01-13 19:14:49 -08:00
viyatb-oai
ab28660a52 Revert "Patch rama-crypto to drop bindgen"
This reverts commit 74d748cefb.
2026-01-13 19:14:45 -08:00
viyatb-oai
e6194d5c89 Revert "Bazel: skip aws-lc-sys memcmp check"
This reverts commit 0dd709317f.
2026-01-13 19:14:40 -08:00
viyatb-oai
0bbe48c03e Revert "Fix aws-lc-sys patch hunk header"
This reverts commit 1906a23fa1.
2026-01-13 19:14:25 -08:00
viyatb-oai
1906a23fa1 Fix aws-lc-sys patch hunk header 2026-01-12 18:39:29 -08:00
viyatb-oai
0dd709317f Bazel: skip aws-lc-sys memcmp check 2026-01-12 18:32:07 -08:00
viyatb-oai
74d748cefb Patch rama-crypto to drop bindgen 2026-01-12 18:21:03 -08:00
viyatb-oai
981c7c3261 Bazel: drop aws-lc bindgen in rama-crypto 2026-01-12 18:05:12 -08:00
viyatb-oai
826e40683e Bazel: disable cmake for aws-lc-sys 2026-01-12 17:52:32 -08:00
viyatb-oai
a60515bc85 Install cmake for Bazel CI 2026-01-12 17:39:45 -08:00
viyatb-oai
6ef1dd9917 Remove vendored rama-tls-rustls 2026-01-12 17:25:19 -08:00
viyatb-oai
ef2c2d3131 Fix CI: cargo-shear, cargo-deny, bazel 2026-01-12 16:32:33 -08:00
viyatb-oai
d2042b92b6 Update network proxy rama deps 2026-01-12 15:25:25 -08:00
viyatb-oai
310c79eef5 Merge branch 'main' into pr/network-proxy-crate 2026-01-12 12:12:53 -08:00
viyatb-oai
ee102bcb63 fix test 2025-12-24 11:41:24 -08:00
viyatb-oai
3e9046128a adding back assert_cmd 2025-12-24 10:58:10 -08:00
viyatb-oai
e60d43c3a7 fix cargo shear 2025-12-24 10:16:47 -08:00
viyatb-oai
4f3097b585 Merge branch 'main' into pr/network-proxy-crate 2025-12-24 09:50:20 -08:00
viyatb-oai
9b2a353e6e explicitly name controls 2025-12-23 23:26:26 -08:00
viyatb-oai
10abb38b53 tighten escape mechanisms 2025-12-23 23:18:47 -08:00
viyatb-oai
2d7980340d add comments 2025-12-23 18:40:20 -08:00
viyatb-oai
6f4edec9f1 consolidate docs 2025-12-23 18:34:53 -08:00
viyatb-oai
fc35891b07 fix old artifacts from refactor 2025-12-23 18:27:39 -08:00
viyatb-oai
dc063ff890 add unit tests and re-add crate back to cargo 2025-12-23 18:19:42 -08:00
viyatb-oai
9d473922e3 address feedback 2025-12-23 18:03:55 -08:00
viyatb-oai
127b89b4ed Merge branch 'main' into pr/network-proxy-crate 2025-12-23 15:34:21 -08:00
viyatb-oai
9b20af68f0 use a general path 2025-12-22 16:01:55 -08:00
viyatb-oai
83e8a702fb use rama instead of implementing our own proxy stack 2025-12-21 21:49:51 -08:00
viyatb-oai
eceb76bf3d use better examples 2025-12-21 14:17:12 -08:00
viyatb-oai
f65edf9c91 Add codex-network-proxy crate 2025-12-21 12:15:59 -08:00
24 changed files with 5843 additions and 24 deletions

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env bash
# Install musl cross-compilation toolchains in CI and export compiler settings
# by appending KEY=VALUE lines to the GitHub Actions environment file.
#
# Required env:  TARGET (musl triple), GITHUB_ENV (env file path)
# Optional env:  APT_UPDATE_ARGS / APT_INSTALL_ARGS (extra apt flags, word-split)
set -euo pipefail

: "${TARGET:?TARGET environment variable is required}"
: "${GITHUB_ENV:?GITHUB_ENV environment variable is required}"

# Append one KEY=VALUE line to the workflow environment.
emit_env() {
  echo "$1" >> "$GITHUB_ENV"
}

extra_update_args=()
if [[ -n "${APT_UPDATE_ARGS:-}" ]]; then
  # Intentional word-splitting of the caller-supplied flag string.
  # shellcheck disable=SC2206
  extra_update_args=(${APT_UPDATE_ARGS})
fi

extra_install_args=()
if [[ -n "${APT_INSTALL_ARGS:-}" ]]; then
  # Intentional word-splitting of the caller-supplied flag string.
  # shellcheck disable=SC2206
  extra_install_args=(${APT_INSTALL_ARGS})
fi

sudo apt-get update "${extra_update_args[@]}"
sudo apt-get install -y "${extra_install_args[@]}" musl-tools pkg-config g++ clang libc++-dev libc++abi-dev lld

# Map the Rust target triple to the toolchain architecture prefix.
case "${TARGET}" in
  x86_64-unknown-linux-musl)
    arch="x86_64"
    ;;
  aarch64-unknown-linux-musl)
    arch="aarch64"
    ;;
  *)
    echo "Unexpected musl target: ${TARGET}" >&2
    exit 1
    ;;
esac

# Prefer clang/clang++; otherwise fall back through musl-specific g++ wrappers.
if command -v clang++ >/dev/null; then
  cxx="$(command -v clang++)"
  emit_env "CXXFLAGS=--target=${TARGET} -stdlib=libc++ -pthread"
  emit_env "CFLAGS=--target=${TARGET} -pthread"
  if command -v clang >/dev/null; then
    cc="$(command -v clang)"
    emit_env "CC=${cc}"
    emit_env "TARGET_CC=${cc}"
    # e.g. CC_x86_64_unknown_linux_musl, the per-target form cc-rs understands.
    target_cc_var="CC_${TARGET}"
    target_cc_var="${target_cc_var//-/_}"
    emit_env "${target_cc_var}=${cc}"
  fi
elif command -v "${arch}-linux-musl-g++" >/dev/null; then
  cxx="$(command -v "${arch}-linux-musl-g++")"
elif command -v musl-g++ >/dev/null; then
  cxx="$(command -v musl-g++)"
elif command -v musl-gcc >/dev/null; then
  # Last-resort fallback: use musl-gcc as the CXX driver.
  cxx="$(command -v musl-gcc)"
  emit_env "CFLAGS=-pthread"
else
  echo "musl g++ not found after install; arch=${arch}" >&2
  exit 1
fi

emit_env "CXX=${cxx}"
emit_env "CMAKE_CXX_COMPILER=${cxx}"
emit_env "CMAKE_ARGS=-DCMAKE_HAVE_THREADS_LIBRARY=1 -DCMAKE_USE_PTHREADS_INIT=1 -DCMAKE_THREAD_LIBS_INIT=-pthread -DTHREADS_PREFER_PTHREAD_FLAG=ON"

View File

@@ -265,11 +265,11 @@ jobs:
name: Install musl build tools
env:
DEBIAN_FRONTEND: noninteractive
TARGET: ${{ matrix.target }}
APT_UPDATE_ARGS: -o Acquire::Retries=3
APT_INSTALL_ARGS: --no-install-recommends
shell: bash
run: |
set -euo pipefail
sudo apt-get -y update -o Acquire::Retries=3
sudo apt-get -y install --no-install-recommends musl-tools pkg-config
run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
- name: Install cargo-chef
if: ${{ matrix.profile == 'release' }}

View File

@@ -96,9 +96,9 @@ jobs:
- if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
name: Install musl build tools
run: |
sudo apt-get update
sudo apt-get install -y musl-tools pkg-config
env:
TARGET: ${{ matrix.target }}
run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
- name: Cargo build
shell: bash

View File

@@ -99,9 +99,9 @@ jobs:
- if: ${{ matrix.install_musl }}
name: Install musl build dependencies
run: |
sudo apt-get update
sudo apt-get install -y musl-tools pkg-config
env:
TARGET: ${{ matrix.target }}
run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
- name: Build exec server binaries
run: cargo build --release --target ${{ matrix.target }} --bin codex-exec-mcp-server --bin codex-execve-wrapper

1001
codex-rs/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -27,6 +27,7 @@ members = [
"login",
"mcp-server",
"mcp-types",
"network-proxy",
"ollama",
"process-hardening",
"protocol",
@@ -138,6 +139,7 @@ env-flags = "0.1.1"
env_logger = "0.11.5"
eventsource-stream = "0.2.3"
futures = { version = "0.3", default-features = false }
globset = "0.4"
http = "1.3.1"
icu_decimal = "2.1"
icu_locale_core = "2.1"

View File

@@ -73,6 +73,7 @@ ignore = [
{ id = "RUSTSEC-2024-0388", reason = "derivative is unmaintained; pulled in via starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
{ id = "RUSTSEC-2025-0057", reason = "fxhash is unmaintained; pulled in via starlark_map/starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
{ id = "RUSTSEC-2024-0436", reason = "paste is unmaintained; pulled in via ratatui/rmcp/starlark used by tui/execpolicy; no fixed release yet" },
{ id = "RUSTSEC-2025-0134", reason = "rustls-pemfile is unmaintained; pulled in via rama-tls-rustls used by codex-network-proxy; no safe upgrade until rama removes the dependency" },
# TODO(joshka, nornagon): remove this exception when once we update the ratatui fork to a version that uses lru 0.13+.
{ id = "RUSTSEC-2026-0002", reason = "lru 0.12.5 is pulled in via ratatui fork; cannot upgrade until the fork is updated" },
]

View File

@@ -0,0 +1,45 @@
# Manifest for `codex-network-proxy`: Codex's local network policy enforcement
# proxy, shipped as both a binary and an embeddable library.
[package]
name = "codex-network-proxy"
edition = "2024"
version = { workspace = true }
license.workspace = true

[[bin]]
name = "codex-network-proxy"
path = "src/main.rs"

[lib]
name = "codex_network_proxy"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-app-server-protocol = { workspace = true }
codex-core = { workspace = true }
globset = { workspace = true }
# NOTE(review): `rcgen` is imported under the local alias `rcgen-rama`; the
# reason for the rename is not documented here — confirm before changing.
rcgen-rama = { package = "rcgen", version = "0.14", default-features = false, features = ["pem", "x509-parser", "ring"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
time = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt"] }
# All rama crates are pinned (`=`) to the same pre-release so they move in lockstep.
rama-core = { version = "=0.3.0-alpha.4" }
rama-http = { version = "=0.3.0-alpha.4" }
rama-http-backend = { version = "=0.3.0-alpha.4", features = ["tls"] }
rama-net = { version = "=0.3.0-alpha.4", features = ["http", "tls"] }
rama-socks5 = { version = "=0.3.0-alpha.4" }
rama-tcp = { version = "=0.3.0-alpha.4", features = ["http"] }
rama-tls-boring = { version = "=0.3.0-alpha.4", features = ["http"] }
rama-utils = { version = "=0.3.0-alpha.4" }

[dev-dependencies]
pretty_assertions = { workspace = true }

# Unix-only dependency (unix socket proxying support).
[target.'cfg(target_family = "unix")'.dependencies]
rama-unix = { version = "=0.3.0-alpha.4" }

View File

@@ -0,0 +1,202 @@
# codex-network-proxy
`codex-network-proxy` is Codex's local network policy enforcement proxy. It runs:
- an HTTP proxy (default `127.0.0.1:3128`)
- a SOCKS5 proxy (default `127.0.0.1:8081`)
- an admin HTTP API (default `127.0.0.1:8080`)
It enforces an allow/deny policy and a "limited" mode intended for read-only network access.
## Quickstart
### 1) Configure
`codex-network-proxy` reads from Codex's merged `config.toml` (via `codex-core` config loading).
Example config:
```toml
[network_proxy]
enabled = true
proxy_url = "http://127.0.0.1:3128"
admin_url = "http://127.0.0.1:8080"
# SOCKS5 listens on 127.0.0.1:8081 by default. Override via `NetworkProxyBuilder::socks_addr`.
# When `enabled` is false, the proxy no-ops and does not bind listeners.
# When true, respect HTTP(S)_PROXY/ALL_PROXY for upstream requests (HTTP(S) proxies only),
# including CONNECT tunnels in full mode.
allow_upstream_proxy = false
# By default, non-loopback binds are clamped to loopback for safety.
# If you want to expose these listeners beyond localhost, you must opt in explicitly.
dangerously_allow_non_loopback_proxy = false
dangerously_allow_non_loopback_admin = false
mode = "limited" # or "full"
[network_proxy.mitm]
# When enabled, HTTPS CONNECT can be terminated so limited-mode method policy still applies.
# CA cert/key paths are relative to CODEX_HOME by default.
enabled = false
ca_cert_path = "network_proxy/mitm/ca.pem"
ca_key_path = "network_proxy/mitm/ca.key"
# Maximum size of request/response bodies MITM will buffer for inspection.
max_body_bytes = 1048576
[network_proxy.policy]
# Hosts must match the allowlist (unless denied).
# If `allowed_domains` is empty, the proxy blocks requests until an allowlist is configured.
allowed_domains = ["*.openai.com"]
denied_domains = ["evil.example"]
# If false, local/private networking is rejected. Explicit allowlisting of local IP literals
# (or `localhost`) is required to permit them.
# Hostnames that resolve to local/private IPs are still blocked even if allowlisted.
allow_local_binding = false
# macOS-only: allows proxying to a unix socket when request includes `x-unix-socket: /path`.
allow_unix_sockets = ["/tmp/example.sock"]
```
### 2) Run the proxy
```bash
cargo run -p codex-network-proxy --
```
If you plan to enable MITM, initialize the default directory first:
```bash
cargo run -p codex-network-proxy -- init
```
The proxy will generate a local CA on first MITM use if the files do not exist. Import the
generated CA cert into your system trust store to avoid TLS errors.
### 3) Point a client at it
For HTTP(S) traffic:
```bash
export HTTP_PROXY="http://127.0.0.1:3128"
export HTTPS_PROXY="http://127.0.0.1:3128"
```
For SOCKS5 traffic:
```bash
export ALL_PROXY="socks5h://127.0.0.1:8081"
```
To enable SOCKS5 UDP associate support:
```bash
cargo run -p codex-network-proxy -- --enable-socks5-udp
```
### 4) Understand blocks / debugging
When a request is blocked, the proxy responds with `403` and includes:
- `x-proxy-error`: one of:
- `blocked-by-allowlist`
- `blocked-by-denylist`
- `blocked-by-method-policy`
- `blocked-by-mitm-required`
- `blocked-by-policy`
In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed. HTTPS CONNECT requests require
MITM to enforce limited-mode method policy; otherwise they are blocked.
## Library API
`codex-network-proxy` can be embedded as a library with a thin API:
```rust
use codex_network_proxy::{NetworkProxy, NetworkDecision, NetworkPolicyRequest};
let proxy = NetworkProxy::builder()
.http_addr("127.0.0.1:8080".parse()?)
.admin_addr("127.0.0.1:9000".parse()?)
.policy_decider(|request: NetworkPolicyRequest| async move {
// Example: auto-allow when exec policy already approved a command prefix.
if let Some(command) = request.command.as_deref() {
if command.starts_with("curl ") {
return NetworkDecision::Allow;
}
}
NetworkDecision::Deny {
reason: "policy_denied".to_string(),
}
})
.build()
.await?;
let handle = proxy.run().await?;
handle.shutdown().await?;
```
When unix socket proxying is enabled, HTTP/admin bind overrides are still clamped to loopback
to avoid turning the proxy into a remote bridge to local daemons.
### Policy hook (exec-policy mapping)
The proxy exposes a policy hook (`NetworkPolicyDecider`) that can override allowlist-only blocks.
It receives `command` and `exec_policy_hint` fields when supplied by the embedding app. This lets
core map exec approvals to network access, e.g. if a user already approved `curl *` for a session,
the decider can auto-allow network requests originating from that command.
**Important:** Explicit deny rules still win. The decider only gets a chance to override
`not_allowed` (allowlist misses), not `denied` or `not_allowed_local`.
## Admin API
The admin API is a small HTTP server intended for debugging and runtime adjustments.
Endpoints:
```bash
curl -sS http://127.0.0.1:8080/health
curl -sS http://127.0.0.1:8080/config
curl -sS http://127.0.0.1:8080/patterns
curl -sS http://127.0.0.1:8080/blocked
# Switch modes without restarting:
curl -sS -X POST http://127.0.0.1:8080/mode -d '{"mode":"full"}'
# Force a config reload:
curl -sS -X POST http://127.0.0.1:8080/reload
```
## Platform notes
- Unix socket proxying via the `x-unix-socket` header is **macOS-only**; other platforms will
reject unix socket requests.
- HTTPS tunneling uses BoringSSL via Rama's `rama-tls-boring`; building the proxy requires a
native toolchain and CMake on macOS/Linux/Windows.
## Security notes (important)
This section documents the protections implemented by `codex-network-proxy`, and the boundaries of
what it can reasonably guarantee.
- Allowlist-first policy: if `allowed_domains` is empty, requests are blocked until an allowlist is configured.
- Deny wins: entries in `denied_domains` always override the allowlist.
- Local/private network protection: when `allow_local_binding = false`, the proxy blocks loopback
and common private/link-local ranges. Explicit allowlisting of local IP literals (or `localhost`)
is required to permit them; hostnames that resolve to local/private IPs are still blocked even if
allowlisted (best-effort DNS lookup).
- Limited mode enforcement:
- only `GET`, `HEAD`, and `OPTIONS` are allowed
- HTTPS CONNECT is blocked unless MITM is enabled
- Listener safety defaults:
- the admin API is unauthenticated; non-loopback binds are clamped unless explicitly enabled via
`dangerously_allow_non_loopback_admin`
- the HTTP proxy listener similarly clamps non-loopback binds unless explicitly enabled via
`dangerously_allow_non_loopback_proxy`
- when unix socket proxying is enabled, both listeners are forced to loopback to avoid turning the
proxy into a remote bridge into local daemons.
- `enabled` is enforced at runtime; when false the proxy no-ops and does not bind listeners.
Limitations:
- DNS rebinding is hard to fully prevent without pinning the resolved IP(s) all the way down to the
transport layer. If your threat model includes hostile DNS, enforce network egress at a lower
layer too (e.g., firewall / VPC / corporate proxy policies).

View File

@@ -0,0 +1,157 @@
use crate::config::NetworkMode;
use crate::responses::json_response;
use crate::responses::text_response;
use crate::state::AppState;
use anyhow::Context;
use anyhow::Result;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http_backend::server::HttpServer;
use rama_tcp::server::TcpListener;
use serde::Deserialize;
use serde::Serialize;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
/// Bind the admin API listener on `addr` and serve requests until the listener stops.
///
/// Debug-only surface: health/config/patterns/blocked plus mode/reload toggles.
/// Policy itself stays config-driven and constraint-enforced; this endpoint must
/// not grow into a second policy/approval plane.
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
    let listener = TcpListener::build()
        .bind(addr)
        .await
        // See `http_proxy.rs` for why `BoxError` is wrapped before the anyhow conversion.
        .map_err(rama_core::error::OpaqueError::from)
        .map_err(anyhow::Error::from)
        .with_context(|| format!("bind admin API: {addr}"))?;
    info!("admin API listening on {addr}");
    let shared = state.clone();
    let service = HttpServer::auto(Executor::new()).service(service_fn(move |req| {
        let state = shared.clone();
        async move { handle_admin_request(state, req).await }
    }));
    listener.serve(service).await;
    Ok(())
}
/// Dispatch a single admin API request by `(method, path)`.
///
/// Routes:
/// - `GET /health`   — liveness probe, plain "ok"
/// - `GET /config`   — current config as JSON
/// - `GET /patterns` — current allow/deny pattern lists
/// - `GET /blocked`  — drains the blocked-request queue
/// - `POST /mode`    — switch network mode (JSON body, capped at 8 KiB)
/// - `POST /reload`  — force a config reload
///
/// Always returns `Ok(_)`: every failure is mapped to an HTTP error response,
/// so the server's error type is `Infallible`.
async fn handle_admin_request(state: Arc<AppState>, req: Request) -> Result<Response, Infallible> {
    // Upper bound on the `POST /mode` body so a client cannot make us buffer unboundedly.
    const MODE_BODY_LIMIT: usize = 8 * 1024;
    let method = req.method().clone();
    let path = req.uri().path().to_string();
    let response = match (method.as_str(), path.as_str()) {
        ("GET", "/health") => Response::new(Body::from("ok")),
        ("GET", "/config") => match state.current_cfg().await {
            Ok(cfg) => json_response(&cfg),
            Err(err) => {
                error!("failed to load config: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
            }
        },
        ("GET", "/patterns") => match state.current_patterns().await {
            Ok((allow, deny)) => json_response(&PatternsResponse {
                allowed: allow,
                denied: deny,
            }),
            Err(err) => {
                error!("failed to load patterns: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
            }
        },
        ("GET", "/blocked") => match state.drain_blocked().await {
            Ok(blocked) => json_response(&BlockedResponse { blocked }),
            Err(err) => {
                error!("failed to read blocked queue: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
            }
        },
        ("POST", "/mode") => {
            // Read the body chunk-by-chunk, enforcing MODE_BODY_LIMIT as we go
            // (checked before appending, using saturating math to avoid overflow).
            let mut body = req.into_body();
            let mut buf: Vec<u8> = Vec::new();
            loop {
                let chunk = match body.chunk().await {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        error!("failed to read mode body: {err}");
                        return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
                    }
                };
                // `None` signals end-of-body.
                let Some(chunk) = chunk else {
                    break;
                };
                if buf.len().saturating_add(chunk.len()) > MODE_BODY_LIMIT {
                    return Ok(text_response(
                        StatusCode::PAYLOAD_TOO_LARGE,
                        "body too large",
                    ));
                }
                buf.extend_from_slice(&chunk);
            }
            if buf.is_empty() {
                return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
            }
            let update: ModeUpdate = match serde_json::from_slice(&buf) {
                Ok(update) => update,
                Err(err) => {
                    error!("failed to parse mode update: {err}");
                    return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
                }
            };
            match state.set_network_mode(update.mode).await {
                Ok(()) => json_response(&ModeUpdateResponse {
                    status: "ok",
                    mode: update.mode,
                }),
                Err(err) => {
                    error!("mode update failed: {err}");
                    text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
                }
            }
        }
        ("POST", "/reload") => match state.force_reload().await {
            Ok(()) => json_response(&ReloadResponse { status: "reloaded" }),
            Err(err) => {
                error!("reload failed: {err}");
                text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
            }
        },
        _ => text_response(StatusCode::NOT_FOUND, "not found"),
    };
    Ok(response)
}
/// Body of `POST /mode`: the desired network mode.
#[derive(Deserialize)]
struct ModeUpdate {
    mode: NetworkMode,
}
/// Payload for `GET /patterns`.
#[derive(Debug, Serialize)]
struct PatternsResponse {
    // Allow patterns currently in effect.
    allowed: Vec<String>,
    // Deny patterns currently in effect.
    denied: Vec<String>,
}
/// Payload for `GET /blocked`; generic over whatever `drain_blocked` yields.
#[derive(Debug, Serialize)]
struct BlockedResponse<T> {
    blocked: T,
}
/// Success payload for `POST /mode`, echoing the applied mode.
#[derive(Debug, Serialize)]
struct ModeUpdateResponse {
    status: &'static str,
    mode: NetworkMode,
}
/// Success payload for `POST /reload`.
#[derive(Debug, Serialize)]
struct ReloadResponse {
    status: &'static str,
}

View File

@@ -0,0 +1,417 @@
use serde::Deserialize;
use serde::Serialize;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::path::PathBuf;
use tracing::warn;
/// Top-level config document; only the `[network_proxy]` table is modeled here.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
    #[serde(default)]
    pub network_proxy: NetworkProxyConfig,
}
/// `[network_proxy]` settings (see the crate README for the full contract).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkProxyConfig {
    /// Master switch; when false the proxy no-ops and binds no listeners.
    #[serde(default)]
    pub enabled: bool,
    /// HTTP proxy listener URL (default `http://127.0.0.1:3128`).
    #[serde(default = "default_proxy_url")]
    pub proxy_url: String,
    /// Admin API listener URL (default `http://127.0.0.1:8080`).
    #[serde(default = "default_admin_url")]
    pub admin_url: String,
    /// When true, honor HTTP(S)_PROXY/ALL_PROXY for upstream requests (per README).
    #[serde(default)]
    pub allow_upstream_proxy: bool,
    /// Opt-in to a non-loopback HTTP proxy bind; otherwise binds are clamped to loopback.
    #[serde(default)]
    pub dangerously_allow_non_loopback_proxy: bool,
    /// Opt-in to a non-loopback admin API bind; otherwise binds are clamped to loopback.
    #[serde(default)]
    pub dangerously_allow_non_loopback_admin: bool,
    /// Enforcement mode; note `NetworkMode::default()` is `Full`.
    #[serde(default)]
    pub mode: NetworkMode,
    /// Domain/host allow/deny policy.
    #[serde(default)]
    pub policy: NetworkPolicy,
    /// HTTPS interception settings.
    #[serde(default)]
    pub mitm: MitmConfig,
}
// Manual `Default` mirrors the serde field defaults so programmatic construction
// matches deserializing an empty table.
impl Default for NetworkProxyConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            proxy_url: default_proxy_url(),
            admin_url: default_admin_url(),
            allow_upstream_proxy: false,
            dangerously_allow_non_loopback_proxy: false,
            dangerously_allow_non_loopback_admin: false,
            mode: NetworkMode::default(),
            policy: NetworkPolicy::default(),
            mitm: MitmConfig::default(),
        }
    }
}
/// Host/domain policy knobs (`[network_proxy.policy]`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NetworkPolicy {
    /// Host patterns that may be reached; per README, empty means all requests are blocked.
    #[serde(default)]
    pub allowed_domains: Vec<String>,
    /// Host patterns that are always blocked; deny wins over allow (per README).
    #[serde(default)]
    pub denied_domains: Vec<String>,
    /// Unix socket paths reachable via the `x-unix-socket` header (macOS-only per README).
    #[serde(default)]
    pub allow_unix_sockets: Vec<String>,
    /// When false, local/private destinations are rejected unless explicitly allowlisted.
    #[serde(default)]
    pub allow_local_binding: bool,
}
/// Enforcement mode; serialized lowercase (`"limited"` / `"full"`).
///
/// Note: the serde/`Default` default is `Full`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum NetworkMode {
    /// Read-only policy: per README, only GET/HEAD/OPTIONS are permitted.
    Limited,
    /// No method restrictions.
    #[default]
    Full,
}
/// HTTPS interception settings (`[network_proxy.mitm]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitmConfig {
    /// Master switch for MITM; off by default.
    #[serde(default)]
    pub enabled: bool,
    // NOTE(review): `inspect` is not covered by the README — confirm its exact semantics.
    #[serde(default)]
    pub inspect: bool,
    /// Cap on request/response body bytes MITM buffers for inspection (default 4096).
    #[serde(default = "default_mitm_max_body_bytes")]
    pub max_body_bytes: usize,
    /// CA certificate path; per README, relative paths resolve against CODEX_HOME.
    #[serde(default = "default_ca_cert_path")]
    pub ca_cert_path: PathBuf,
    /// CA private key path; per README, relative paths resolve against CODEX_HOME.
    #[serde(default = "default_ca_key_path")]
    pub ca_key_path: PathBuf,
}
// Manual `Default` mirrors the serde field defaults.
impl Default for MitmConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            inspect: false,
            max_body_bytes: default_mitm_max_body_bytes(),
            ca_cert_path: default_ca_cert_path(),
            ca_key_path: default_ca_key_path(),
        }
    }
}
/// Default HTTP proxy listener URL.
fn default_proxy_url() -> String {
    String::from("http://127.0.0.1:3128")
}

/// Default admin API listener URL.
fn default_admin_url() -> String {
    String::from("http://127.0.0.1:8080")
}

/// Default MITM CA certificate location (relative to CODEX_HOME).
fn default_ca_cert_path() -> PathBuf {
    PathBuf::from("network_proxy/mitm/ca.pem")
}

/// Default MITM CA private key location (relative to CODEX_HOME).
fn default_ca_key_path() -> PathBuf {
    PathBuf::from("network_proxy/mitm/ca.key")
}

/// Default cap on MITM-buffered body bytes (4 KiB).
fn default_mitm_max_body_bytes() -> usize {
    4096
}
/// Clamp a requested bind address to loopback unless the caller opted in.
///
/// Loopback addresses pass through unchanged. Non-loopback addresses pass
/// through only when `allow_non_loopback` is set (with a loud warning);
/// otherwise the bind is rewritten to `127.0.0.1` on the same port.
fn clamp_non_loopback(addr: SocketAddr, allow_non_loopback: bool, name: &str) -> SocketAddr {
    match (addr.ip().is_loopback(), allow_non_loopback) {
        (true, _) => addr,
        (false, true) => {
            warn!("DANGEROUS: {name} listening on non-loopback address {addr}");
            addr
        }
        (false, false) => {
            warn!(
                "{name} requested non-loopback bind ({addr}); clamping to 127.0.0.1:{port} (set dangerously_allow_non_loopback_proxy or dangerously_allow_non_loopback_admin to override)",
                port = addr.port()
            );
            SocketAddr::from(([127, 0, 0, 1], addr.port()))
        }
    }
}
/// Apply the loopback clamping policy to both listener addresses.
///
/// Each address is first clamped individually (honoring its opt-out flag).
/// When unix socket proxying is enabled, both listeners are additionally
/// forced onto loopback regardless of the opt-out flags.
pub(crate) fn clamp_bind_addrs(
    http_addr: SocketAddr,
    admin_addr: SocketAddr,
    cfg: &NetworkProxyConfig,
) -> (SocketAddr, SocketAddr) {
    let proxy_bind = clamp_non_loopback(
        http_addr,
        cfg.dangerously_allow_non_loopback_proxy,
        "HTTP proxy",
    );
    let admin_bind = clamp_non_loopback(
        admin_addr,
        cfg.dangerously_allow_non_loopback_admin,
        "admin API",
    );
    if cfg.policy.allow_unix_sockets.is_empty() {
        return (proxy_bind, admin_bind);
    }
    // `x-unix-socket` is a local-only escape hatch. If either listener were reachable
    // from off-box, the proxy would become a remote bridge into local daemons
    // (e.g. docker.sock). So whenever unix sockets are enabled, loopback binding is
    // enforced unconditionally — warning when that overrides an explicit opt-out.
    if cfg.dangerously_allow_non_loopback_proxy && !proxy_bind.ip().is_loopback() {
        warn!(
            "unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_proxy and clamping HTTP proxy to loopback"
        );
    }
    if cfg.dangerously_allow_non_loopback_admin && !admin_bind.ip().is_loopback() {
        warn!(
            "unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_admin and clamping admin API to loopback"
        );
    }
    let to_loopback = |port: u16| SocketAddr::from(([127, 0, 0, 1], port));
    (to_loopback(proxy_bind.port()), to_loopback(admin_bind.port()))
}
/// Final, clamped socket addresses the proxy actually binds.
pub struct RuntimeConfig {
    /// HTTP proxy listener address.
    pub http_addr: SocketAddr,
    /// SOCKS5 listener address (fixed to 127.0.0.1:8081 by `resolve_runtime`).
    pub socks_addr: SocketAddr,
    /// Admin API listener address.
    pub admin_addr: SocketAddr,
}
/// Compute the effective listener addresses from config, applying loopback clamping.
pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
    let proxy_cfg = &cfg.network_proxy;
    let requested_http = resolve_addr(&proxy_cfg.proxy_url, 3128);
    let requested_admin = resolve_addr(&proxy_cfg.admin_url, 8080);
    let (http_addr, admin_addr) = clamp_bind_addrs(requested_http, requested_admin, proxy_cfg);
    RuntimeConfig {
        http_addr,
        // SOCKS5 is always bound to loopback:8081 here; per the README, the
        // listener address is overridden via `NetworkProxyBuilder::socks_addr`.
        socks_addr: SocketAddr::from(([127, 0, 0, 1], 8081)),
        admin_addr,
    }
}
/// Turn a configured URL-ish string into a bindable `SocketAddr`.
///
/// `localhost` is mapped to `127.0.0.1`; any non-IP host falls back to
/// loopback on the parsed (or default) port, since we only ever bind locally.
fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
    let parts = parse_host_port(url, default_port);
    let host = if parts.host.eq_ignore_ascii_case("localhost") {
        "127.0.0.1"
    } else {
        parts.host
    };
    host.parse::<IpAddr>()
        .map(|ip| SocketAddr::new(ip, parts.port))
        .unwrap_or_else(|_| SocketAddr::from(([127, 0, 0, 1], parts.port)))
}

/// Borrowed host plus resolved port extracted from a URL-ish string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SocketAddressParts<'a> {
    host: &'a str,
    port: u16,
}

/// Best-effort extraction of `host` and `port` from a URL-like string.
///
/// Handles optional scheme (`scheme://`), path suffix, userinfo (`user:pass@`),
/// and bracketed IPv6 literals. An empty/blank input yields loopback with the
/// default port; an unparsable port falls back to `default_port`.
fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
    let cleaned = url.trim();
    if cleaned.is_empty() {
        return SocketAddressParts {
            host: "127.0.0.1",
            port: default_port,
        };
    }
    // Drop the scheme, keep only the authority component, then drop userinfo.
    let after_scheme = match cleaned.split_once("://") {
        Some((_, rest)) => rest,
        None => cleaned,
    };
    let authority = after_scheme.split('/').next().unwrap_or(after_scheme);
    let authority = match authority.rsplit_once('@') {
        Some((_, rest)) => rest,
        None => authority,
    };
    // Bracketed IPv6 literal: `[host]` optionally followed by `:port`.
    if let Some(inner) = authority.strip_prefix('[') {
        if let Some((host, tail)) = inner.split_once(']') {
            let port = tail
                .strip_prefix(':')
                .and_then(|p| p.parse::<u16>().ok())
                .unwrap_or(default_port);
            return SocketAddressParts { host, port };
        }
    }
    // Only treat `host:port` as such when there's exactly one `:`; this keeps
    // unbracketed IPv6 addresses from being misread as `host:port`.
    if authority.bytes().filter(|&b| b == b':').count() == 1 {
        if let Some((host, port_text)) = authority.rsplit_once(':') {
            if let Ok(port) = port_text.parse::<u16>() {
                return SocketAddressParts { host, port };
            }
        }
    }
    SocketAddressParts {
        host: authority,
        port: default_port,
    }
}
// Unit tests for address parsing/resolution and bind clamping.
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    // --- parse_host_port ---

    #[test]
    fn parse_host_port_defaults_for_empty_string() {
        assert_eq!(
            parse_host_port("", 1234),
            SocketAddressParts {
                host: "127.0.0.1",
                port: 1234,
            }
        );
    }
    #[test]
    fn parse_host_port_defaults_for_whitespace() {
        assert_eq!(
            parse_host_port(" ", 5555),
            SocketAddressParts {
                host: "127.0.0.1",
                port: 5555,
            }
        );
    }
    #[test]
    fn parse_host_port_parses_host_port_without_scheme() {
        assert_eq!(
            parse_host_port("127.0.0.1:8080", 3128),
            SocketAddressParts {
                host: "127.0.0.1",
                port: 8080,
            }
        );
    }
    #[test]
    fn parse_host_port_parses_host_port_with_scheme_and_path() {
        assert_eq!(
            parse_host_port("http://example.com:8080/some/path", 3128),
            SocketAddressParts {
                host: "example.com",
                port: 8080,
            }
        );
    }
    #[test]
    fn parse_host_port_strips_userinfo() {
        assert_eq!(
            parse_host_port("http://user:pass@host.example:5555", 3128),
            SocketAddressParts {
                host: "host.example",
                port: 5555,
            }
        );
    }
    #[test]
    fn parse_host_port_parses_ipv6_with_brackets() {
        assert_eq!(
            parse_host_port("http://[::1]:9999", 3128),
            SocketAddressParts {
                host: "::1",
                port: 9999,
            }
        );
    }
    #[test]
    fn parse_host_port_does_not_treat_unbracketed_ipv6_as_host_port() {
        assert_eq!(
            parse_host_port("2001:db8::1", 3128),
            SocketAddressParts {
                host: "2001:db8::1",
                port: 3128,
            }
        );
    }
    #[test]
    fn parse_host_port_falls_back_to_default_port_when_port_is_invalid() {
        assert_eq!(
            parse_host_port("example.com:notaport", 3128),
            SocketAddressParts {
                host: "example.com:notaport",
                port: 3128,
            }
        );
    }

    // --- resolve_addr ---

    #[test]
    fn resolve_addr_maps_localhost_to_loopback() {
        assert_eq!(
            resolve_addr("localhost", 3128),
            "127.0.0.1:3128".parse::<SocketAddr>().unwrap()
        );
    }
    #[test]
    fn resolve_addr_parses_ip_literals() {
        assert_eq!(
            resolve_addr("1.2.3.4", 80),
            "1.2.3.4:80".parse::<SocketAddr>().unwrap()
        );
    }
    #[test]
    fn resolve_addr_parses_ipv6_literals() {
        assert_eq!(
            resolve_addr("http://[::1]:8080", 3128),
            "[::1]:8080".parse::<SocketAddr>().unwrap()
        );
    }
    #[test]
    fn resolve_addr_falls_back_to_loopback_for_hostnames() {
        assert_eq!(
            resolve_addr("http://example.com:5555", 3128),
            "127.0.0.1:5555".parse::<SocketAddr>().unwrap()
        );
    }

    // --- clamp_bind_addrs ---

    #[test]
    fn clamp_bind_addrs_allows_non_loopback_when_enabled() {
        let cfg = NetworkProxyConfig {
            dangerously_allow_non_loopback_proxy: true,
            dangerously_allow_non_loopback_admin: true,
            ..Default::default()
        };
        let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
        let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
        let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
        assert_eq!(http_addr, "0.0.0.0:3128".parse::<SocketAddr>().unwrap());
        assert_eq!(admin_addr, "0.0.0.0:8080".parse::<SocketAddr>().unwrap());
    }
    #[test]
    fn clamp_bind_addrs_forces_loopback_when_unix_sockets_enabled() {
        // Unix socket proxying must override the dangerous non-loopback opt-ins.
        let cfg = NetworkProxyConfig {
            dangerously_allow_non_loopback_proxy: true,
            dangerously_allow_non_loopback_admin: true,
            policy: NetworkPolicy {
                allow_unix_sockets: vec!["/tmp/docker.sock".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
        let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
        let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
        assert_eq!(http_addr, "127.0.0.1:3128".parse::<SocketAddr>().unwrap());
        assert_eq!(admin_addr, "127.0.0.1:8080".parse::<SocketAddr>().unwrap());
    }
}

View File

@@ -0,0 +1,626 @@
use crate::config::NetworkMode;
use crate::mitm;
use crate::network_policy::NetworkDecision;
use crate::network_policy::NetworkPolicyDecider;
use crate::network_policy::NetworkPolicyRequest;
use crate::network_policy::NetworkProtocol;
use crate::network_policy::evaluate_host_policy;
use crate::policy::normalize_host;
use crate::responses::blocked_header_value;
use crate::responses::json_response;
use crate::state::AppState;
use crate::state::BlockedRequest;
use crate::upstream::UpstreamClient;
use crate::upstream::proxy_for_connect;
use anyhow::Context as _;
use anyhow::Result;
use rama_core::Layer;
use rama_core::Service;
use rama_core::error::BoxError;
use rama_core::error::ErrorExt as _;
use rama_core::error::OpaqueError;
use rama_core::extensions::ExtensionsMut;
use rama_core::extensions::ExtensionsRef;
use rama_core::layer::AddInputExtensionLayer;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::HeaderValue;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
use rama_http::matcher::MethodMatcher;
use rama_http_backend::client::proxy::layer::HttpProxyConnector;
use rama_http_backend::server::HttpServer;
use rama_http_backend::server::layer::upgrade::UpgradeLayer;
use rama_http_backend::server::layer::upgrade::Upgraded;
use rama_net::Protocol;
use rama_net::address::ProxyAddress;
use rama_net::client::ConnectorService;
use rama_net::client::EstablishedClientConnection;
use rama_net::http::RequestContext;
use rama_net::proxy::ProxyRequest;
use rama_net::proxy::ProxyTarget;
use rama_net::proxy::StreamForwardService;
use rama_net::stream::SocketInfo;
use rama_tcp::client::Request as TcpRequest;
use rama_tcp::client::service::TcpConnector;
use rama_tcp::server::TcpListener;
use rama_tls_boring::client::TlsConnectorDataBuilder;
use rama_tls_boring::client::TlsConnectorLayer;
use serde::Serialize;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
use tracing::warn;
/// Binds `addr` and serves the HTTP proxy on it.
///
/// CONNECT requests are split off by `UpgradeLayer`: `http_connect_accept`
/// applies policy before the tunnel is accepted and `http_connect_proxy`
/// drives the upgraded byte stream. All other requests fall through to
/// `http_plain_proxy`. The shared `AppState` is injected into every request's
/// extensions via `AddInputExtensionLayer`, and the optional external
/// `policy_decider` is cloned into each handler closure.
pub async fn run_http_proxy(
    state: Arc<AppState>,
    addr: SocketAddr,
    policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
) -> Result<()> {
    let listener = TcpListener::build()
        .bind(addr)
        .await
        // Rama's `BoxError` is a `Box<dyn Error + Send + Sync>` without an explicit `'static`
        // lifetime bound, which means it doesn't satisfy `anyhow::Context`'s `StdError` constraint.
        // Wrap it in Rama's `OpaqueError` so we can preserve the original error as a source and
        // still use `anyhow` for chaining.
        .map_err(rama_core::error::OpaqueError::from)
        .map_err(anyhow::Error::from)
        .with_context(|| format!("bind HTTP proxy: {addr}"))?;
    let http_service = HttpServer::auto(Executor::new()).service(
        (
            UpgradeLayer::new(
                MethodMatcher::CONNECT,
                service_fn({
                    let policy_decider = policy_decider.clone();
                    move |req| http_connect_accept(policy_decider.clone(), req)
                }),
                service_fn(http_connect_proxy),
            ),
            // Hop-by-hop headers are connection-scoped; strip them in both
            // directions before forwarding.
            RemoveResponseHeaderLayer::hop_by_hop(),
            RemoveRequestHeaderLayer::hop_by_hop(),
        )
        .into_layer(service_fn({
            let policy_decider = policy_decider.clone();
            move |req| http_plain_proxy(policy_decider.clone(), req)
        })),
    );
    info!("HTTP proxy listening on {addr}");
    listener
        .serve(AddInputExtensionLayer::new(state).into_layer(http_service))
        .await;
    Ok(())
}
/// Pre-upgrade hook for CONNECT requests.
///
/// Runs the policy pipeline before the tunnel is accepted: resolves the
/// target authority, checks the proxy's enabled flag, evaluates host policy
/// (optionally via the external `policy_decider`), and — in limited mode —
/// requires MITM so the inner HTTPS traffic stays inspectable.
///
/// On success, returns the `200` response to send to the client plus the
/// request, annotated via extensions with `ProxyTarget`, the `NetworkMode`,
/// and (when configured) the MITM state for the upgrade handler. On failure,
/// returns the error response to send instead; every denial is recorded via
/// `record_blocked`.
async fn http_connect_accept(
    policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
    mut req: Request,
) -> Result<(Response, Request), Response> {
    let app_state = req
        .extensions()
        .get::<Arc<AppState>>()
        .cloned()
        .ok_or_else(|| text_response(StatusCode::INTERNAL_SERVER_ERROR, "missing state"))?;
    let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
        Ok(authority) => authority,
        Err(err) => {
            warn!("CONNECT missing authority: {err}");
            return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
        }
    };
    let host = normalize_host(&authority.host.to_string());
    if host.is_empty() {
        return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
    }
    let client = client_addr(&req);
    // A disabled proxy refuses everything (and records the attempt) rather
    // than silently passing traffic through.
    let enabled = match app_state.enabled().await {
        Ok(enabled) => enabled,
        Err(err) => {
            error!("failed to read enabled state: {err}");
            return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    if !enabled {
        let _ = app_state
            .record_blocked(BlockedRequest::new(
                host.clone(),
                "proxy_disabled".to_string(),
                client.clone(),
                Some("CONNECT".to_string()),
                None,
                "http-connect".to_string(),
            ))
            .await;
        let client = client.as_deref().unwrap_or_default();
        warn!("CONNECT blocked; proxy disabled (client={client}, host={host})");
        return Err(text_response(
            StatusCode::SERVICE_UNAVAILABLE,
            "proxy disabled",
        ));
    }
    let request = NetworkPolicyRequest::new(
        NetworkProtocol::HttpsConnect,
        host.clone(),
        authority.port,
        client.clone(),
        Some("CONNECT".to_string()),
        None,
        None,
    );
    match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
        Ok(NetworkDecision::Deny { reason }) => {
            let _ = app_state
                .record_blocked(BlockedRequest::new(
                    host.clone(),
                    reason.clone(),
                    client.clone(),
                    Some("CONNECT".to_string()),
                    None,
                    "http-connect".to_string(),
                ))
                .await;
            let client = client.as_deref().unwrap_or_default();
            warn!("CONNECT blocked (client={client}, host={host}, reason={reason})");
            return Err(blocked_text(&reason));
        }
        Ok(NetworkDecision::Allow) => {
            let client = client.as_deref().unwrap_or_default();
            info!("CONNECT allowed (client={client}, host={host})");
        }
        Err(err) => {
            error!("failed to evaluate host for CONNECT {host}: {err}");
            return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    }
    let mode = match app_state.network_mode().await {
        Ok(mode) => mode,
        Err(err) => {
            error!("failed to read network mode: {err}");
            return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    let mitm_state = match app_state.mitm_state().await {
        Ok(state) => state,
        Err(err) => {
            error!("failed to load MITM state: {err}");
            return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    if mode == NetworkMode::Limited && mitm_state.is_none() {
        // Limited mode is designed to be read-only. Without MITM, a CONNECT tunnel would hide the
        // inner HTTP method/headers from the proxy, effectively bypassing method policy.
        let _ = app_state
            .record_blocked(BlockedRequest::new(
                host.clone(),
                "mitm_required".to_string(),
                client.clone(),
                Some("CONNECT".to_string()),
                Some(NetworkMode::Limited),
                "http-connect".to_string(),
            ))
            .await;
        let client = client.as_deref().unwrap_or_default();
        warn!(
            "CONNECT blocked; MITM required for read-only HTTPS in limited mode (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
        );
        return Err(blocked_text("mitm_required"));
    }
    // Hand everything the post-upgrade handler needs over via extensions.
    req.extensions_mut().insert(ProxyTarget(authority));
    req.extensions_mut().insert(mode);
    if let Some(mitm_state) = mitm_state {
        req.extensions_mut().insert(mitm_state);
    }
    Ok((
        Response::builder()
            .status(StatusCode::OK)
            .body(Body::empty())
            .unwrap_or_else(|_| Response::new(Body::empty())),
        req,
    ))
}
/// Post-upgrade CONNECT handler: either terminates TLS locally (MITM, in
/// limited mode) or blindly forwards bytes to the target.
///
/// Policy was already enforced by `http_connect_accept`; this function only
/// routes the established tunnel. Errors are logged rather than propagated
/// because the HTTP exchange with the client has already completed.
async fn http_connect_proxy(upgraded: Upgraded) -> Result<(), Infallible> {
    let mode = upgraded
        .extensions()
        .get::<NetworkMode>()
        .copied()
        .unwrap_or(NetworkMode::Full);
    let Some(target) = upgraded
        .extensions()
        .get::<ProxyTarget>()
        .map(|t| t.0.clone())
    else {
        warn!("CONNECT missing proxy target");
        return Ok(());
    };
    let host = normalize_host(&target.host.to_string());
    // Limited mode + MITM state present: decrypt the tunnel so the inner
    // requests can be policy-checked per method.
    if mode == NetworkMode::Limited
        && upgraded
            .extensions()
            .get::<Arc<mitm::MitmState>>()
            .is_some()
    {
        let port = target.port;
        info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
        if let Err(err) = mitm::mitm_tunnel(upgraded).await {
            warn!("MITM tunnel error: {err}");
        }
        return Ok(());
    }
    // If the upstream-proxy setting cannot be read, fail closed and connect
    // directly instead of routing through an (unverified) upstream proxy.
    let allow_upstream_proxy = match upgraded.extensions().get::<Arc<AppState>>().cloned() {
        Some(state) => match state.allow_upstream_proxy().await {
            Ok(allowed) => allowed,
            Err(err) => {
                error!("failed to read upstream proxy setting: {err}");
                false
            }
        },
        None => {
            error!("missing app state");
            false
        }
    };
    let proxy = if allow_upstream_proxy {
        proxy_for_connect()
    } else {
        None
    };
    if let Err(err) = forward_connect_tunnel(upgraded, proxy).await {
        warn!("tunnel error: {err}");
    }
    Ok(())
}
/// Opens a connection to the CONNECT target — optionally via an upstream
/// HTTP proxy — and pumps bytes between it and the client until either side
/// closes.
async fn forward_connect_tunnel(
    upgraded: Upgraded,
    proxy: Option<ProxyAddress>,
) -> Result<(), BoxError> {
    let authority = upgraded
        .extensions()
        .get::<ProxyTarget>()
        .map(|target| target.0.clone())
        .ok_or_else(|| OpaqueError::from_display("missing forward authority").into_boxed())?;
    let mut extensions = upgraded.extensions().clone();
    if let Some(proxy) = proxy {
        // The upstream proxy address is passed via extensions; presumably
        // `HttpProxyConnector::optional` engages only when it is present —
        // confirm against the rama proxy-connector docs.
        extensions.insert(proxy);
    }
    let req = TcpRequest::new_with_extensions(authority.clone(), extensions)
        .with_protocol(Protocol::HTTPS);
    let proxy_connector = HttpProxyConnector::optional(TcpConnector::new());
    let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
    let connector = TlsConnectorLayer::tunnel(None)
        .with_connector_data(tls_config)
        .into_layer(proxy_connector);
    let EstablishedClientConnection { conn: target, .. } =
        connector.connect(req).await.map_err(|err| {
            OpaqueError::from_boxed(err)
                .with_context(|| format!("establish CONNECT tunnel to {authority}"))
                .into_boxed()
        })?;
    // Bidirectional byte copy between the client tunnel and the target.
    let proxy_req = ProxyRequest {
        source: upgraded,
        target,
    };
    StreamForwardService::default()
        .serve(proxy_req)
        .await
        .map_err(|err| {
            OpaqueError::from_boxed(err.into())
                .with_context(|| format!("forward CONNECT tunnel to {authority}"))
                .into_boxed()
        })
}
/// Handler for plain (absolute-form) HTTP requests and the `x-unix-socket`
/// escape hatch.
///
/// Enforcement order for unix-socket requests: header validity, proxy
/// enabled, method policy, platform support (macOS only), socket allowlist.
/// For regular requests: proxy enabled, host policy, method policy. Every
/// denial is recorded via `record_blocked` and answered with a 403/503; this
/// handler never returns an error to the server stack.
async fn http_plain_proxy(
    policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
    req: Request,
) -> Result<Response, Infallible> {
    let app_state = match req.extensions().get::<Arc<AppState>>().cloned() {
        Some(state) => state,
        None => {
            error!("missing app state");
            return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    let client = client_addr(&req);
    // Evaluated once up front; reused by both the unix-socket and regular paths.
    let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
        Ok(allowed) => allowed,
        Err(err) => {
            error!("failed to evaluate method policy: {err}");
            return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    // `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly scoped:
    // macOS-only + explicit allowlist, to avoid turning the proxy into a general local capability
    // escalation mechanism.
    if let Some(unix_socket_header) = req.headers().get("x-unix-socket") {
        let socket_path = match unix_socket_header.to_str() {
            Ok(value) => value.to_string(),
            Err(_) => {
                warn!("invalid x-unix-socket header value (non-UTF8)");
                return Ok(text_response(
                    StatusCode::BAD_REQUEST,
                    "invalid x-unix-socket header",
                ));
            }
        };
        let enabled = match app_state.enabled().await {
            Ok(enabled) => enabled,
            Err(err) => {
                error!("failed to read enabled state: {err}");
                return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
            }
        };
        if !enabled {
            let _ = app_state
                .record_blocked(BlockedRequest::new(
                    socket_path.clone(),
                    "proxy_disabled".to_string(),
                    client.clone(),
                    Some(req.method().as_str().to_string()),
                    None,
                    "unix-socket".to_string(),
                ))
                .await;
            let client = client.as_deref().unwrap_or_default();
            warn!("unix socket blocked; proxy disabled (client={client}, path={socket_path})");
            return Ok(text_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "proxy disabled",
            ));
        }
        if !method_allowed {
            let client = client.as_deref().unwrap_or_default();
            let method = req.method();
            warn!(
                "unix socket blocked by method policy (client={client}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
            );
            return Ok(json_blocked("unix-socket", "method_not_allowed"));
        }
        if !cfg!(target_os = "macos") {
            warn!("unix socket proxy unsupported on this platform (path={socket_path})");
            return Ok(text_response(
                StatusCode::NOT_IMPLEMENTED,
                "unix sockets unsupported",
            ));
        }
        match app_state.is_unix_socket_allowed(&socket_path).await {
            Ok(true) => {
                let client = client.as_deref().unwrap_or_default();
                info!("unix socket allowed (client={client}, path={socket_path})");
                match proxy_via_unix_socket(req, &socket_path).await {
                    Ok(resp) => return Ok(resp),
                    Err(err) => {
                        warn!("unix socket proxy failed: {err}");
                        return Ok(text_response(
                            StatusCode::BAD_GATEWAY,
                            "unix socket proxy failed",
                        ));
                    }
                }
            }
            Ok(false) => {
                let client = client.as_deref().unwrap_or_default();
                warn!("unix socket blocked (client={client}, path={socket_path})");
                return Ok(json_blocked("unix-socket", "not_allowed"));
            }
            Err(err) => {
                warn!("unix socket check failed: {err}");
                return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
            }
        }
    }
    // Regular proxied request: the target comes from the absolute-form URI /
    // Host header via `RequestContext`.
    let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
        Ok(authority) => authority,
        Err(err) => {
            warn!("missing host: {err}");
            return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
        }
    };
    let host = normalize_host(&authority.host.to_string());
    let port = authority.port;
    let enabled = match app_state.enabled().await {
        Ok(enabled) => enabled,
        Err(err) => {
            error!("failed to read enabled state: {err}");
            return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    if !enabled {
        let _ = app_state
            .record_blocked(BlockedRequest::new(
                host.clone(),
                "proxy_disabled".to_string(),
                client.clone(),
                Some(req.method().as_str().to_string()),
                None,
                "http".to_string(),
            ))
            .await;
        let client = client.as_deref().unwrap_or_default();
        let method = req.method();
        warn!("request blocked; proxy disabled (client={client}, host={host}, method={method})");
        return Ok(text_response(
            StatusCode::SERVICE_UNAVAILABLE,
            "proxy disabled",
        ));
    }
    let request = NetworkPolicyRequest::new(
        NetworkProtocol::Http,
        host.clone(),
        port,
        client.clone(),
        Some(req.method().as_str().to_string()),
        None,
        None,
    );
    match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
        Ok(NetworkDecision::Deny { reason }) => {
            let _ = app_state
                .record_blocked(BlockedRequest::new(
                    host.clone(),
                    reason.clone(),
                    client.clone(),
                    Some(req.method().as_str().to_string()),
                    None,
                    "http".to_string(),
                ))
                .await;
            let client = client.as_deref().unwrap_or_default();
            warn!("request blocked (client={client}, host={host}, reason={reason})");
            return Ok(json_blocked(&host, &reason));
        }
        Ok(NetworkDecision::Allow) => {}
        Err(err) => {
            error!("failed to evaluate host for {host}: {err}");
            return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    }
    if !method_allowed {
        let _ = app_state
            .record_blocked(BlockedRequest::new(
                host.clone(),
                "method_not_allowed".to_string(),
                client.clone(),
                Some(req.method().as_str().to_string()),
                Some(NetworkMode::Limited),
                "http".to_string(),
            ))
            .await;
        let client = client.as_deref().unwrap_or_default();
        let method = req.method();
        warn!(
            "request blocked by method policy (client={client}, host={host}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
        );
        return Ok(json_blocked(&host, "method_not_allowed"));
    }
    let client = client.as_deref().unwrap_or_default();
    let method = req.method();
    info!("request allowed (client={client}, host={host}, method={method})");
    let allow_upstream_proxy = match app_state.allow_upstream_proxy().await {
        Ok(allow) => allow,
        Err(err) => {
            error!("failed to read upstream proxy config: {err}");
            return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
        }
    };
    let client = if allow_upstream_proxy {
        UpstreamClient::from_env_proxy()
    } else {
        UpstreamClient::direct()
    };
    match client.serve(req).await {
        Ok(resp) => Ok(resp),
        Err(err) => {
            warn!("upstream request failed: {err}");
            Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
        }
    }
}
/// Replays `req` against a local unix-domain socket (macOS builds only).
///
/// The request URI is reduced to its path+query (unix sockets have no
/// authority component) and the internal `x-unix-socket` routing header is
/// stripped before forwarding. On non-macOS targets this always errors.
async fn proxy_via_unix_socket(req: Request, socket_path: &str) -> Result<Response> {
    #[cfg(target_os = "macos")]
    {
        let client = UpstreamClient::unix_socket(socket_path);
        let (mut parts, body) = req.into_parts();
        let path = parts
            .uri
            .path_and_query()
            .map(rama_http::uri::PathAndQuery::as_str)
            .unwrap_or("/");
        parts.uri = path
            .parse()
            .with_context(|| format!("invalid unix socket request path: {path}"))?;
        // Internal routing header must not leak to the local daemon.
        parts.headers.remove("x-unix-socket");
        let req = Request::from_parts(parts, body);
        client.serve(req).await.map_err(anyhow::Error::from)
    }
    #[cfg(not(target_os = "macos"))]
    {
        let _ = req;
        let _ = socket_path;
        Err(anyhow::anyhow!("unix sockets not supported"))
    }
}
/// Best-effort peer address of the client, read from the connection's
/// `SocketInfo` extension; `None` when no socket info was recorded.
fn client_addr<T: ExtensionsRef>(input: &T) -> Option<String> {
    let info = input.extensions().get::<SocketInfo>()?;
    Some(info.peer_addr().to_string())
}
/// Builds the 403 JSON denial response (`{"status":"blocked", ...}`), with a
/// machine-readable copy of the reason in the `x-proxy-error` header.
fn json_blocked(host: &str, reason: &str) -> Response {
    let response = BlockedResponse {
        status: "blocked",
        host,
        reason,
    };
    let mut resp = json_response(&response);
    *resp.status_mut() = StatusCode::FORBIDDEN;
    resp.headers_mut().insert(
        "x-proxy-error",
        HeaderValue::from_static(blocked_header_value(reason)),
    );
    resp
}
/// Plain-text denial response; thin wrapper around
/// `responses::blocked_text_response`, used on the CONNECT path.
fn blocked_text(reason: &str) -> Response {
    crate::responses::blocked_text_response(reason)
}
/// Builds a plain-text response with the given status code. If the builder
/// ever fails, falls back to a bare response carrying the same body.
fn text_response(status: StatusCode, body: &str) -> Response {
    let text = body.to_string();
    let built = Response::builder()
        .status(status)
        .header("content-type", "text/plain")
        .body(Body::from(text.clone()));
    match built {
        Ok(resp) => resp,
        Err(_) => Response::new(Body::from(text)),
    }
}
/// JSON payload serialized into `json_blocked` denial responses.
#[derive(Serialize)]
struct BlockedResponse<'a> {
    // Always the literal "blocked".
    status: &'static str,
    host: &'a str,
    reason: &'a str,
}

View File

@@ -0,0 +1,17 @@
use anyhow::Context;
use anyhow::Result;
use codex_core::config::find_codex_home;
use std::fs;
/// Ensures the on-disk layout used by the network proxy exists:
/// `$CODEX_HOME/network_proxy` and its `mitm` subdirectory.
pub fn run_init() -> Result<()> {
    let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
    let root = codex_home.join("network_proxy");
    let mitm_dir = root.join("mitm");
    // NOTE(review): `create_dir_all(&mitm_dir)` alone would also create `root`;
    // creating `root` first keeps each error message specific to its level.
    fs::create_dir_all(&root).with_context(|| format!("failed to create {}", root.display()))?;
    fs::create_dir_all(&mitm_dir)
        .with_context(|| format!("failed to create {}", mitm_dir.display()))?;
    println!("ensured {}", mitm_dir.display());
    Ok(())
}

View File

@@ -0,0 +1,35 @@
mod admin;
mod config;
mod http_proxy;
mod init;
mod mitm;
mod network_policy;
mod policy;
mod proxy;
mod responses;
mod runtime;
mod socks5;
mod state;
mod upstream;
use anyhow::Result;
pub use network_policy::NetworkDecision;
pub use network_policy::NetworkPolicyDecider;
pub use network_policy::NetworkPolicyRequest;
pub use network_policy::NetworkProtocol;
pub use proxy::Args;
pub use proxy::Command;
pub use proxy::NetworkProxy;
pub use proxy::NetworkProxyBuilder;
pub use proxy::NetworkProxyHandle;
pub use proxy::run_init;
/// Library-level CLI entry point: handles the one-shot `init` subcommand,
/// otherwise builds the proxy from the CLI args, runs it, and waits for
/// shutdown.
pub async fn run_main(args: Args) -> Result<()> {
    if matches!(args.command, Some(Command::Init)) {
        run_init()?;
        Ok(())
    } else {
        let proxy = NetworkProxy::from_cli_args(args).await?;
        proxy.run().await?.wait().await
    }
}

View File

@@ -0,0 +1,13 @@
use anyhow::Result;
use clap::Parser;
use codex_network_proxy::Args;
use codex_network_proxy::NetworkProxy;
// Binary entry point: install the default tracing subscriber, parse CLI
// args, then build and run the proxy until it shuts down.
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let args = Args::parse();
    let proxy = NetworkProxy::from_cli_args(args).await?;
    // `run` starts the listeners; `wait` blocks until the proxy stops.
    proxy.run().await?.wait().await
}

View File

@@ -0,0 +1,622 @@
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use crate::policy::method_allowed;
use crate::policy::normalize_host;
use crate::responses::blocked_text_response;
use crate::state::AppState;
use crate::state::BlockedRequest;
use crate::upstream::UpstreamClient;
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use rama_core::Layer;
use rama_core::Service;
use rama_core::bytes::Bytes;
use rama_core::error::BoxError;
use rama_core::extensions::ExtensionsRef;
use rama_core::futures::stream::Stream;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::BodyDataStream;
use rama_http::HeaderValue;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http::Uri;
use rama_http::header::HOST;
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
use rama_http_backend::server::HttpServer;
use rama_http_backend::server::layer::upgrade::Upgraded;
use rama_net::proxy::ProxyTarget;
use rama_net::stream::SocketInfo;
use rama_net::tls::ApplicationProtocol;
use rama_net::tls::DataEncoding;
use rama_net::tls::server::ServerAuth;
use rama_net::tls::server::ServerAuthData;
use rama_net::tls::server::ServerConfig;
use rama_tls_boring::server::TlsAcceptorData;
use rama_tls_boring::server::TlsAcceptorLayer;
use rama_utils::str::NonEmptyStr;
use std::fs;
use std::fs::File;
use std::fs::OpenOptions;
use std::io::Write;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context as TaskContext;
use std::task::Poll;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use tracing::info;
use tracing::warn;
use rcgen_rama::BasicConstraints;
use rcgen_rama::CertificateParams;
use rcgen_rama::DistinguishedName;
use rcgen_rama::DnType;
use rcgen_rama::ExtendedKeyUsagePurpose;
use rcgen_rama::IsCa;
use rcgen_rama::Issuer;
use rcgen_rama::KeyPair;
use rcgen_rama::KeyUsagePurpose;
use rcgen_rama::SanType;
/// Everything needed to terminate TLS inside a MITM'd CONNECT tunnel: the CA
/// that mints per-host leaf certificates, the upstream client that replays
/// decrypted requests, and the body-inspection settings.
pub struct MitmState {
    // CA issuer used to sign a fresh leaf certificate per target host.
    issuer: Issuer<'static, KeyPair>,
    // Client used to forward the decrypted request to the real origin.
    upstream: UpstreamClient,
    // Whether bodies are wrapped by `inspect_body` for summary logging.
    inspect: bool,
    // Bodies longer than this are reported as `truncated` in inspection logs.
    max_body_bytes: usize,
}
impl std::fmt::Debug for MitmState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Avoid dumping internal state (CA material, connectors, etc.) to logs.
        f.debug_struct("MitmState")
            .field("inspect", &self.inspect)
            .field("max_body_bytes", &self.max_body_bytes)
            .finish_non_exhaustive()
    }
}
impl MitmState {
    /// Loads (or generates) the MITM CA and builds the upstream client used
    /// to replay decrypted requests.
    ///
    /// # Errors
    /// Fails when the CA material cannot be loaded, created, or parsed.
    pub fn new(cfg: &MitmConfig, allow_upstream_proxy: bool) -> Result<Self> {
        // MITM exists to make limited-mode HTTPS enforceable: once CONNECT is established, plain
        // proxying would lose visibility into the inner HTTP request. We generate/load a local CA
        // and issue per-host leaf certs so we can terminate TLS and apply policy.
        let (ca_cert_pem, ca_key_pem) = load_or_create_ca(cfg)?;
        let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
        let issuer: Issuer<'static, KeyPair> =
            Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key).context("failed to parse CA cert")?;
        let upstream = if allow_upstream_proxy {
            UpstreamClient::from_env_proxy()
        } else {
            UpstreamClient::direct()
        };
        Ok(Self {
            issuer,
            upstream,
            inspect: cfg.inspect,
            max_body_bytes: cfg.max_body_bytes,
        })
    }
    /// Issues a leaf certificate for `host` and wraps it in a boring TLS
    /// acceptor config advertising h2 and http/1.1 via ALPN.
    fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
        let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
        let cert_chain = DataEncoding::Pem(
            NonEmptyStr::try_from(cert_pem.as_str()).context("failed to encode host cert PEM")?,
        );
        let private_key = DataEncoding::Pem(
            NonEmptyStr::try_from(key_pem.as_str()).context("failed to encode host key PEM")?,
        );
        let auth = ServerAuthData {
            private_key,
            cert_chain,
            ocsp: None,
        };
        let mut server_config = ServerConfig::new(ServerAuth::Single(auth));
        server_config.application_layer_protocol_negotiation = Some(vec![
            ApplicationProtocol::HTTP_2,
            ApplicationProtocol::HTTP_11,
        ]);
        TlsAcceptorData::try_from(server_config).context("failed to build boring acceptor config")
    }
    /// Whether body inspection (summary logging) is enabled.
    pub fn inspect_enabled(&self) -> bool {
        self.inspect
    }
    /// Threshold above which inspected bodies are logged as truncated.
    pub fn max_body_bytes(&self) -> usize {
        self.max_body_bytes
    }
}
/// Serves a MITM'd CONNECT tunnel: terminates TLS with a freshly issued leaf
/// certificate for the target host, then runs an HTTP server over the
/// decrypted stream so each inner request flows through
/// `handle_mitm_request`.
///
/// Requires `MitmState` and `ProxyTarget` to be present in the upgraded
/// connection's extensions (both are inserted by the CONNECT accept hook).
pub async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
    let state = upgraded
        .extensions()
        .get::<Arc<MitmState>>()
        .cloned()
        .context("missing MITM state")?;
    let target = upgraded
        .extensions()
        .get::<ProxyTarget>()
        .context("missing proxy target")?
        .0
        .clone();
    let host = normalize_host(&target.host.to_string());
    let acceptor_data = state.tls_acceptor_data_for_host(&host)?;
    // Reuse the connection's executor when one was attached; fall back to a
    // default executor otherwise.
    let executor = upgraded
        .extensions()
        .get::<Executor>()
        .cloned()
        .unwrap_or_default();
    let http_service = HttpServer::auto(executor).service(
        (
            // Hop-by-hop headers are connection-scoped; strip both directions.
            RemoveResponseHeaderLayer::hop_by_hop(),
            RemoveRequestHeaderLayer::hop_by_hop(),
        )
        .into_layer(service_fn(handle_mitm_request)),
    );
    let https_service = TlsAcceptorLayer::new(acceptor_data)
        .with_store_client_hello(true)
        .into_layer(http_service);
    https_service
        .serve(upgraded)
        .await
        .map_err(|err| anyhow!("MITM serve error: {err}"))?;
    Ok(())
}
async fn handle_mitm_request(req: Request) -> Result<Response, std::convert::Infallible> {
let response = match forward_request(req).await {
Ok(resp) => resp,
Err(err) => {
warn!("MITM upstream request failed: {err}");
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
}
};
Ok(response)
}
/// Forwards one decrypted in-tunnel request to the real origin.
///
/// Rejects nested CONNECT, rejects a Host header that disagrees with the
/// CONNECT target (so a client cannot smuggle traffic to a different host
/// through an approved tunnel), enforces the method policy for the current
/// network mode, rewrites the URI/Host to an absolute `https://` form, and —
/// when inspection is enabled — wraps both bodies for summary logging.
async fn forward_request(req: Request) -> Result<Response> {
    let target = req
        .extensions()
        .get::<ProxyTarget>()
        .context("missing proxy target")?
        .0
        .clone();
    let target_host = normalize_host(&target.host.to_string());
    let target_port = target.port;
    let mode = req
        .extensions()
        .get::<NetworkMode>()
        .copied()
        .unwrap_or(NetworkMode::Full);
    let mitm = req
        .extensions()
        .get::<Arc<MitmState>>()
        .cloned()
        .context("missing MITM state")?;
    let app_state = req
        .extensions()
        .get::<Arc<AppState>>()
        .cloned()
        .context("missing app state")?;
    if req.method().as_str() == "CONNECT" {
        return Ok(text_response(
            StatusCode::METHOD_NOT_ALLOWED,
            "CONNECT not supported inside MITM",
        ));
    }
    let method = req.method().as_str().to_string();
    let path = path_and_query(req.uri());
    let client = req
        .extensions()
        .get::<SocketInfo>()
        .map(|info| info.peer_addr().to_string());
    // The inner request's claimed host must match the tunnel's target host.
    if let Some(request_host) = extract_request_host(&req) {
        let normalized = normalize_host(&request_host);
        if !normalized.is_empty() && normalized != target_host {
            warn!("MITM host mismatch (target={target_host}, request_host={normalized})");
            return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
        }
    }
    if !method_allowed(mode, method.as_str()) {
        let _ = app_state
            .record_blocked(BlockedRequest::new(
                target_host.clone(),
                "method_not_allowed".to_string(),
                client.clone(),
                Some(method.clone()),
                Some(NetworkMode::Limited),
                "https".to_string(),
            ))
            .await;
        warn!(
            "MITM blocked by method policy (host={target_host}, method={method}, path={path}, mode={mode:?}, allowed_methods=GET, HEAD, OPTIONS)"
        );
        return Ok(blocked_text("method_not_allowed"));
    }
    // Rebuild the request as an absolute https URI pointing at the CONNECT
    // target, and force the Host header to match.
    let (mut parts, body) = req.into_parts();
    let authority = authority_header_value(&target_host, target_port);
    parts.uri = build_https_uri(&authority, &path)?;
    parts
        .headers
        .insert(HOST, HeaderValue::from_str(&authority)?);
    let inspect = mitm.inspect_enabled();
    let max_body_bytes = mitm.max_body_bytes();
    let body = if inspect {
        inspect_body(
            body,
            max_body_bytes,
            RequestLogContext {
                host: authority.clone(),
                method: method.clone(),
                path: path.clone(),
            },
        )
    } else {
        body
    };
    let upstream_req = Request::from_parts(parts, body);
    let upstream_resp = mitm.upstream.serve(upstream_req).await?;
    respond_with_inspection(
        upstream_resp,
        inspect,
        max_body_bytes,
        &method,
        &path,
        &authority,
    )
}
/// Returns `resp` unchanged when inspection is off; otherwise wraps its body
/// so a summary log line (host/method/path/status, byte count, truncation
/// flag) is emitted once the body finishes streaming.
fn respond_with_inspection(
    resp: Response,
    inspect: bool,
    max_body_bytes: usize,
    method: &str,
    path: &str,
    authority: &str,
) -> Result<Response> {
    if !inspect {
        return Ok(resp);
    }
    let (parts, body) = resp.into_parts();
    let body = inspect_body(
        body,
        max_body_bytes,
        ResponseLogContext {
            host: authority.to_string(),
            method: method.to_string(),
            path: path.to_string(),
            status: parts.status,
        },
    );
    Ok(Response::from_parts(parts, body))
}
/// Wraps `body` so that, once it has been fully streamed, `ctx` logs its
/// total length and whether it exceeded `max_body_bytes`. The bytes
/// themselves pass through unmodified (nothing is actually truncated).
fn inspect_body<T: BodyLoggable + Send + 'static>(
    body: Body,
    max_body_bytes: usize,
    ctx: T,
) -> Body {
    Body::from_stream(InspectStream {
        inner: Box::pin(body.into_data_stream()),
        ctx: Some(Box::new(ctx)),
        len: 0,
        max_body_bytes,
    })
}
/// Pass-through body stream that counts bytes and logs once at end-of-stream.
struct InspectStream<T> {
    inner: Pin<Box<BodyDataStream>>,
    // Taken exactly once when the inner stream finishes cleanly.
    ctx: Option<Box<T>>,
    // Running byte count; saturating add guards against overflow.
    len: usize,
    max_body_bytes: usize,
}
impl<T: BodyLoggable> Stream for InspectStream<T> {
    type Item = Result<Bytes, BoxError>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match this.inner.as_mut().poll_next(cx) {
            Poll::Ready(Some(Ok(bytes))) => {
                this.len = this.len.saturating_add(bytes.len());
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
            Poll::Ready(None) => {
                // End of stream: emit the one-shot summary. Only a cleanly
                // finished body reaches this branch and gets logged.
                if let Some(ctx) = this.ctx.take() {
                    ctx.log(this.len, this.len > this.max_body_bytes);
                }
                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Identifying fields logged when an inspected request body finishes.
struct RequestLogContext {
    host: String,
    method: String,
    path: String,
}
/// Identifying fields logged when an inspected response body finishes.
struct ResponseLogContext {
    host: String,
    method: String,
    path: String,
    status: StatusCode,
}
/// One-shot sink for the end-of-body summary; `log` consumes `self` so it
/// can only fire once per body.
trait BodyLoggable {
    fn log(self, len: usize, truncated: bool);
}
impl BodyLoggable for RequestLogContext {
    /// Emits the single summary line for an inspected request body.
    fn log(self, len: usize, truncated: bool) {
        let Self { host, method, path } = self;
        info!(
            "MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
        );
    }
}
impl BodyLoggable for ResponseLogContext {
    /// Emits the single summary line for an inspected response body.
    fn log(self, len: usize, truncated: bool) {
        let Self {
            host,
            method,
            path,
            status,
        } = self;
        info!(
            "MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
        );
    }
}
/// The host the client *claims* to be talking to: the `Host` header when
/// present and valid UTF-8, otherwise the request URI's authority.
fn extract_request_host(req: &Request) -> Option<String> {
    req.headers()
        .get(HOST)
        .and_then(|v| v.to_str().ok())
        .map(ToString::to_string)
        .or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
}
/// Formats `host`/`port` as a Host-header / URI-authority value.
///
/// IPv6 literals (detected by a `:` in the host) are bracketed, and the
/// default HTTPS port 443 is omitted.
fn authority_header_value(host: &str, port: u16) -> String {
    let is_ipv6 = host.contains(':');
    match (is_ipv6, port) {
        (true, 443) => format!("[{host}]"),
        (true, _) => format!("[{host}]:{port}"),
        (false, 443) => host.to_string(),
        (false, _) => format!("{host}:{port}"),
    }
}
/// Reassembles the absolute `https://` URI for the upstream request from the
/// tunnel authority and the original path+query.
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
    let target = format!("https://{authority}{path}");
    Ok(target.parse()?)
}
/// Path plus query string of `uri`, defaulting to "/" when absent.
fn path_and_query(uri: &Uri) -> String {
    uri.path_and_query()
        .map(rama_http::uri::PathAndQuery::as_str)
        .unwrap_or("/")
        .to_string()
}
/// Issues a leaf certificate for `host`, signed by the MITM CA.
///
/// IP-literal hosts get an IP SAN; everything else is treated as a DNS name.
/// A fresh P-256 key pair is generated per call; returns
/// `(cert_pem, key_pem)`.
fn issue_host_certificate_pem(
    host: &str,
    issuer: &Issuer<'_, KeyPair>,
) -> Result<(String, String)> {
    let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
        let mut params = CertificateParams::new(Vec::new())
            .map_err(|err| anyhow!("failed to create cert params: {err}"))?;
        params.subject_alt_names.push(SanType::IpAddress(ip));
        params
    } else {
        CertificateParams::new(vec![host.to_string()])
            .map_err(|err| anyhow!("failed to create cert params: {err}"))?
    };
    // Scope the leaf cert to TLS server use only.
    params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
    params.key_usages = vec![
        KeyUsagePurpose::DigitalSignature,
        KeyUsagePurpose::KeyEncipherment,
    ];
    let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
        .map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
    let cert = params
        .signed_by(&key_pair, issuer)
        .map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
    Ok((cert.pem(), key_pair.serialize_pem()))
}
/// Loads the MITM CA cert/key PEM pair from the configured paths, or
/// generates and persists a new CA when neither file exists yet.
///
/// A half-present CA (only the cert or only the key on disk) is an error
/// rather than something we silently regenerate over.
fn load_or_create_ca(cfg: &MitmConfig) -> Result<(String, String)> {
    let cert_path = &cfg.ca_cert_path;
    let key_path = &cfg.ca_key_path;
    if cert_path.exists() || key_path.exists() {
        if !cert_path.exists() || !key_path.exists() {
            return Err(anyhow!("both ca_cert_path and ca_key_path must exist"));
        }
        let cert_pem = fs::read_to_string(cert_path)
            .with_context(|| format!("failed to read CA cert {}", cert_path.display()))?;
        let key_pem = fs::read_to_string(key_path)
            .with_context(|| format!("failed to read CA key {}", key_path.display()))?;
        return Ok((cert_pem, key_pem));
    }
    if let Some(parent) = cert_path.parent() {
        fs::create_dir_all(parent)
            .with_context(|| format!("failed to create {}", parent.display()))?;
    }
    if let Some(parent) = key_path.parent() {
        fs::create_dir_all(parent)
            .with_context(|| format!("failed to create {}", parent.display()))?;
    }
    let (cert_pem, key_pem) = generate_ca()?;
    // The CA key is a high-value secret. Create it atomically with restrictive permissions.
    // The cert can be world-readable, but we still write it atomically to avoid partial writes.
    //
    // We intentionally use create-new semantics: if a key already exists, we should not overwrite
    // it silently (that would invalidate previously-trusted cert chains).
    write_atomic_create_new(key_path, key_pem.as_bytes(), 0o600)
        .with_context(|| format!("failed to persist CA key {}", key_path.display()))?;
    if let Err(err) = write_atomic_create_new(cert_path, cert_pem.as_bytes(), 0o644)
        .with_context(|| format!("failed to persist CA cert {}", cert_path.display()))
    {
        // Avoid leaving a partially-created CA around (cert missing) if the second write fails.
        let _ = fs::remove_file(key_path);
        return Err(err);
    }
    let cert_path = cert_path.display();
    let key_path = key_path.display();
    info!("generated MITM CA (cert_path={cert_path}, key_path={key_path})");
    Ok((cert_pem, key_pem))
}
/// Generate a fresh self-signed CA certificate and private key, both PEM-encoded.
fn generate_ca() -> Result<(String, String)> {
    // Subject: a fixed common name identifying this proxy's CA.
    let mut dn = DistinguishedName::new();
    dn.push(DnType::CommonName, "network_proxy MITM CA");
    let mut params = CertificateParams::default();
    params.distinguished_name = dn;
    params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    params.key_usages = vec![
        KeyUsagePurpose::KeyCertSign,
        KeyUsagePurpose::DigitalSignature,
        KeyUsagePurpose::KeyEncipherment,
    ];
    // ECDSA P-256 key pair for the CA itself.
    let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
        .map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
    let cert = params
        .self_signed(&key_pair)
        .map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
    Ok((cert.pem(), key_pair.serialize_pem()))
}
/// Atomically create `path` with the given contents and (on Unix) permission
/// bits, refusing to overwrite an existing file.
///
/// Strategy: write to a uniquely named temp file in the same directory, fsync
/// it, then publish it via `hard_link` (which fails if the destination already
/// exists), falling back to `rename` on filesystems without hard-link support.
/// The parent directory is fsynced afterwards for durability of the new entry.
fn write_atomic_create_new(path: &std::path::Path, contents: &[u8], mode: u32) -> Result<()> {
    let parent = path
        .parent()
        .ok_or_else(|| anyhow!("missing parent directory"))?;
    // pid + nanosecond timestamp make the temp name unique enough to avoid
    // collisions between concurrent writers.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let pid = std::process::id();
    let file_name = path.file_name().unwrap_or_default().to_string_lossy();
    let tmp_path = parent.join(format!(".{file_name}.tmp.{pid}.{nanos}"));
    let mut file = open_create_new_with_mode(&tmp_path, mode)?;
    file.write_all(contents)
        .with_context(|| format!("failed to write {}", tmp_path.display()))?;
    file.sync_all()
        .with_context(|| format!("failed to fsync {}", tmp_path.display()))?;
    drop(file);
    // Create the final file using "create-new" semantics (no overwrite). `rename` on Unix can
    // overwrite existing files, so prefer a hard-link, which fails if the destination exists.
    match fs::hard_link(&tmp_path, path) {
        Ok(()) => {
            fs::remove_file(&tmp_path)
                .with_context(|| format!("failed to remove {}", tmp_path.display()))?;
        }
        Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
            // Destination appeared since our earlier check; clean up and refuse.
            let _ = fs::remove_file(&tmp_path);
            return Err(anyhow!(
                "refusing to overwrite existing file {}",
                path.display()
            ));
        }
        Err(_) => {
            // Best-effort fallback for environments where hard links are not supported.
            // This is still subject to a TOCTOU race, but the typical case is a private per-user
            // config directory, where other users cannot create files anyway.
            if path.exists() {
                let _ = fs::remove_file(&tmp_path);
                return Err(anyhow!(
                    "refusing to overwrite existing file {}",
                    path.display()
                ));
            }
            fs::rename(&tmp_path, path).with_context(|| {
                format!(
                    "failed to rename {} -> {}",
                    tmp_path.display(),
                    path.display()
                )
            })?;
        }
    }
    // Best-effort durability: ensure the directory entry is persisted too.
    let dir = File::open(parent).with_context(|| format!("failed to open {}", parent.display()))?;
    dir.sync_all()
        .with_context(|| format!("failed to fsync {}", parent.display()))?;
    Ok(())
}
/// Open `path` with create-new semantics (fails when the file already exists),
/// applying the given Unix permission bits at creation time.
#[cfg(unix)]
fn open_create_new_with_mode(path: &std::path::Path, mode: u32) -> Result<File> {
    use std::os::unix::fs::OpenOptionsExt;
    let mut options = OpenOptions::new();
    options.write(true).create_new(true).mode(mode);
    options
        .open(path)
        .with_context(|| format!("failed to create {}", path.display()))
}
/// Non-Unix variant: create-new semantics only; the mode argument is ignored
/// because there is no portable permission-bit API here.
#[cfg(not(unix))]
fn open_create_new_with_mode(path: &std::path::Path, _mode: u32) -> Result<File> {
    let mut options = OpenOptions::new();
    options.write(true).create_new(true);
    options
        .open(path)
        .with_context(|| format!("failed to create {}", path.display()))
}
/// Local alias for the shared blocked-response builder.
fn blocked_text(reason: &str) -> Response {
    blocked_text_response(reason)
}
/// Build a plain-text response with the given status; falls back to a default
/// 200 response carrying the same body if the builder rejects the parts.
fn text_response(status: StatusCode, body: &str) -> Response {
    let payload = body.to_string();
    Response::builder()
        .status(status)
        .header("content-type", "text/plain")
        .body(Body::from(payload))
        .unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}

View File

@@ -0,0 +1,230 @@
use crate::state::AppState;
use anyhow::Result;
use async_trait::async_trait;
use std::future::Future;
use std::sync::Arc;
/// Transport/protocol through which a proxied request arrived.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NetworkProtocol {
    /// Plain HTTP request handled by the HTTP proxy.
    Http,
    /// HTTPS tunnel established via an HTTP CONNECT request.
    HttpsConnect,
    /// TCP stream relayed through the SOCKS5 listener.
    Socks5Tcp,
    /// UDP traffic relayed via SOCKS5 UDP associate.
    Socks5Udp,
}
/// A single network access being evaluated against policy.
#[derive(Clone, Debug)]
pub struct NetworkPolicyRequest {
    pub protocol: NetworkProtocol,
    // Target host (DNS name or IP literal) and port.
    pub host: String,
    pub port: u16,
    // Address of the client that issued the request, when known.
    pub client_addr: Option<String>,
    // HTTP method, when the request came in over HTTP.
    pub method: Option<String>,
    // Originating command / exec-policy hint; lets deciders map exec-policy
    // approvals to network access (see the trait docs below).
    pub command: Option<String>,
    pub exec_policy_hint: Option<String>,
}
impl NetworkPolicyRequest {
    /// Construct a request from its parts; pure field assignment, no validation.
    #[must_use]
    pub fn new(
        protocol: NetworkProtocol,
        host: String,
        port: u16,
        client_addr: Option<String>,
        method: Option<String>,
        command: Option<String>,
        exec_policy_hint: Option<String>,
    ) -> Self {
        Self {
            protocol,
            host,
            port,
            client_addr,
            method,
            command,
            exec_policy_hint,
        }
    }
}
/// Outcome of a policy evaluation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NetworkDecision {
    /// Let the request proceed.
    Allow,
    /// Block the request; `reason` is a machine-readable reason string.
    Deny { reason: String },
}
impl NetworkDecision {
    /// Build a `Deny` decision, substituting "policy_denied" when the supplied
    /// reason is empty so consumers always see a non-empty reason string.
    #[must_use]
    pub fn deny(reason: impl Into<String>) -> Self {
        let mut reason = reason.into();
        if reason.is_empty() {
            reason = "policy_denied".to_string();
        }
        Self::Deny { reason }
    }
}
/// Decide whether a network request should be allowed.
///
/// If `command` or `exec_policy_hint` is provided, callers can map exec-policy
/// approvals to network access (e.g., allow all requests for commands matching
/// approved prefixes like `curl *`).
#[async_trait]
pub trait NetworkPolicyDecider: Send + Sync + 'static {
    /// Return `Allow` or `Deny { reason }` for the given request.
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision;
}
// Forwarding impl: an `Arc<dyn NetworkPolicyDecider>` (or Arc of any concrete
// decider) can itself be used wherever a decider is expected.
#[async_trait]
impl<D: NetworkPolicyDecider + ?Sized> NetworkPolicyDecider for Arc<D> {
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
        (**self).decide(req).await
    }
}
// Blanket impl: any `Fn(NetworkPolicyRequest) -> impl Future<Output =
// NetworkDecision>` closure or function is a decider, which keeps tests and
// simple callers terse.
#[async_trait]
impl<F, Fut> NetworkPolicyDecider for F
where
    F: Fn(NetworkPolicyRequest) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = NetworkDecision> + Send,
{
    async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
        (self)(req).await
    }
}
/// Evaluate the static host policy for `request`, consulting the optional
/// `decider` only when the static verdict is "not_allowed" (host merely
/// absent from the allowlist, neither explicitly denied nor local).
pub(crate) async fn evaluate_host_policy(
    state: &AppState,
    decider: Option<&Arc<dyn NetworkPolicyDecider>>,
    request: &NetworkPolicyRequest,
) -> Result<NetworkDecision> {
    let (blocked, reason) = state.host_blocked(&request.host, request.port).await?;
    if !blocked {
        return Ok(NetworkDecision::Allow);
    }
    // Only the plain allowlist miss is escalated to the decider; explicit
    // denials and local-address blocks are final.
    match decider {
        Some(decider) if reason == "not_allowed" => Ok(decider.decide(request.clone()).await),
        _ => Ok(NetworkDecision::deny(reason)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::NetworkPolicy;
    use crate::state::app_state_for_policy;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    // A "not_allowed" verdict (empty allowlist in the default policy) must be
    // escalated to the decider, whose answer is returned verbatim.
    #[tokio::test]
    async fn evaluate_host_policy_invokes_decider_for_not_allowed() {
        let state = app_state_for_policy(NetworkPolicy::default());
        // Count decider invocations via a shared atomic.
        let calls = Arc::new(AtomicUsize::new(0));
        let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
            let calls = calls.clone();
            move |_req| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { NetworkDecision::Allow }
            }
        });
        let request = NetworkPolicyRequest::new(
            NetworkProtocol::Http,
            "example.com".to_string(),
            80,
            None,
            Some("GET".to_string()),
            None,
            None,
        );
        let decision = evaluate_host_policy(&state, Some(&decider), &request)
            .await
            .unwrap();
        assert_eq!(decision, NetworkDecision::Allow);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
    // An explicit denylist hit is final: the decider must never be consulted.
    #[tokio::test]
    async fn evaluate_host_policy_skips_decider_for_denied() {
        let state = app_state_for_policy(NetworkPolicy {
            allowed_domains: vec!["example.com".to_string()],
            denied_domains: vec!["blocked.com".to_string()],
            ..NetworkPolicy::default()
        });
        let calls = Arc::new(AtomicUsize::new(0));
        let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
            let calls = calls.clone();
            move |_req| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { NetworkDecision::Allow }
            }
        });
        let request = NetworkPolicyRequest::new(
            NetworkProtocol::Http,
            "blocked.com".to_string(),
            80,
            None,
            Some("GET".to_string()),
            None,
            None,
        );
        let decision = evaluate_host_policy(&state, Some(&decider), &request)
            .await
            .unwrap();
        assert_eq!(
            decision,
            NetworkDecision::Deny {
                reason: "denied".to_string()
            }
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
    // A loopback target blocked by allow_local_binding=false is also final;
    // the decider must not be able to override local-address protection.
    #[tokio::test]
    async fn evaluate_host_policy_skips_decider_for_not_allowed_local() {
        let state = app_state_for_policy(NetworkPolicy {
            allowed_domains: vec!["example.com".to_string()],
            allow_local_binding: false,
            ..NetworkPolicy::default()
        });
        let calls = Arc::new(AtomicUsize::new(0));
        let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
            let calls = calls.clone();
            move |_req| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { NetworkDecision::Allow }
            }
        });
        let request = NetworkPolicyRequest::new(
            NetworkProtocol::Http,
            "127.0.0.1".to_string(),
            80,
            None,
            Some("GET".to_string()),
            None,
            None,
        );
        let decision = evaluate_host_policy(&state, Some(&decider), &request)
            .await
            .unwrap();
        assert_eq!(
            decision,
            NetworkDecision::Deny {
                reason: "not_allowed_local".to_string()
            }
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}

View File

@@ -0,0 +1,338 @@
use crate::config::NetworkMode;
use anyhow::Context;
use anyhow::Result;
use globset::GlobBuilder;
use globset::GlobSet;
use globset::GlobSetBuilder;
use std::collections::HashSet;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
/// Whether an HTTP method is permitted under the given network mode.
///
/// `Full` mode allows every method; `Limited` mode permits only the
/// read-style methods GET, HEAD and OPTIONS.
pub fn method_allowed(mode: NetworkMode, method: &str) -> bool {
    match mode {
        NetworkMode::Full => true,
        NetworkMode::Limited => ["GET", "HEAD", "OPTIONS"].contains(&method),
    }
}
/// True when `host` names the loopback interface: "localhost" (case-insensitive,
/// with or without a trailing dot) or any loopback IPv4/IPv6 literal.
pub fn is_loopback_host(host: &str) -> bool {
    let lowered = host.to_ascii_lowercase();
    if matches!(lowered.as_str(), "localhost" | "localhost.") {
        return true;
    }
    lowered
        .parse::<IpAddr>()
        .map(|ip| ip.is_loopback())
        .unwrap_or(false)
}
/// True when `ip` should be treated as non-public (loopback, private,
/// link-local, etc.) for SSRF prevention; dispatches to the per-family helpers.
pub fn is_non_public_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ip) => is_non_public_ipv4(ip),
        IpAddr::V6(ip) => is_non_public_ipv6(ip),
    }
}
/// Classify an IPv4 address as "non-public" for SSRF prevention.
///
/// Uses the standard library classification helpers where possible; they
/// encode the intent more clearly than hand-rolled range checks. Besides
/// loopback/private/link-local/unspecified/multicast, the limited broadcast
/// address (255.255.255.255) and the documentation ranges (192.0.2.0/24,
/// 198.51.100.0/24, 203.0.113.0/24) are treated as non-public too — none of
/// them are globally routable, so letting them through would contradict the
/// SSRF-prevention intent of this check.
fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_broadcast()
        || ip.is_documentation()
}
/// Classify an IPv6 address as "non-public" for SSRF prevention.
fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
    // `to_ipv4` succeeds for IPv4-mapped (`::ffff:a.b.c.d`) and
    // IPv4-compatible (`::a.b.c.d`) forms; those are classified under the
    // IPv4 rules. Note `::1` converts to `0.0.0.1`, which the IPv4 rules do
    // not flag — the extra `is_loopback` check covers that case.
    if let Some(v4) = ip.to_ipv4() {
        return is_non_public_ipv4(v4) || ip.is_loopback();
    }
    // Treat anything that isn't globally routable as "local" for SSRF prevention. In particular:
    // - `::1` loopback
    // - `fc00::/7` unique-local (RFC 4193)
    // - `fe80::/10` link-local
    // - `::` unspecified
    // - multicast ranges
    ip.is_loopback()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_unique_local()
        || ip.is_unicast_link_local()
}
pub fn normalize_host(host: &str) -> String {
let host = host.trim();
if host.starts_with('[')
&& let Some(end) = host.find(']')
{
return normalize_dns_host(&host[1..end]);
}
// The proxy stack should typically hand us a host without a port, but be
// defensive and strip `:port` when there is exactly one `:`.
if host.bytes().filter(|b| *b == b':').count() == 1 {
let host = host.split(':').next().unwrap_or_default();
return normalize_dns_host(host);
}
// Avoid mangling unbracketed IPv6 literals, but strip trailing dots so fully qualified domain
// names are treated the same as their dotless variants.
normalize_dns_host(host)
}
/// Canonical form of a DNS host: lowercase with trailing dots removed.
fn normalize_dns_host(host: &str) -> String {
    host.trim_end_matches('.').to_ascii_lowercase()
}
fn normalize_pattern(pattern: &str) -> String {
let pattern = pattern.trim();
if pattern == "*" {
return "*".to_string();
}
let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
("**.", domain)
} else if let Some(domain) = pattern.strip_prefix("*.") {
("*.", domain)
} else {
("", pattern)
};
let remainder = normalize_host(remainder);
if prefix.is_empty() {
remainder
} else {
format!("{prefix}{remainder}")
}
}
/// Compile a list of domain patterns into a single case-insensitive `GlobSet`.
///
/// Supported domain patterns:
/// - "example.com": match the exact host
/// - "*.example.com": match any subdomain (not the apex)
/// - "**.example.com": match the apex and any subdomain
/// - "*": match any host
///
/// Each expanded glob is added at most once.
pub(crate) fn compile_globset(patterns: &[String]) -> Result<GlobSet> {
    let mut builder = GlobSetBuilder::new();
    let mut seen = HashSet::new();
    for raw in patterns {
        let normalized = normalize_pattern(raw);
        for candidate in expand_domain_pattern(&normalized) {
            // Deduplicate: adding the same glob twice is pointless.
            if seen.contains(&candidate) {
                continue;
            }
            let glob = GlobBuilder::new(&candidate)
                .case_insensitive(true)
                .build()
                .with_context(|| format!("invalid glob pattern: {candidate}"))?;
            builder.add(glob);
            seen.insert(candidate);
        }
    }
    Ok(builder.build()?)
}
/// Parsed form of a domain allow/deny pattern.
#[derive(Debug, Clone)]
pub(crate) enum DomainPattern {
    /// "*": matches any host.
    Any,
    /// "**.example.com": matches the apex domain and every subdomain.
    ApexAndSubdomains(String),
    /// "*.example.com": matches subdomains only, never the apex.
    SubdomainsOnly(String),
    /// "example.com": matches exactly this host.
    Exact(String),
}
impl DomainPattern {
    /// Parse a pattern string into its structured form. Unrecognized prefixes
    /// fall through to `Exact`.
    pub(crate) fn parse(input: &str) -> Self {
        if input == "*" {
            Self::Any
        } else if let Some(domain) = input.strip_prefix("**.") {
            Self::ApexAndSubdomains(domain.to_string())
        } else if let Some(domain) = input.strip_prefix("*.") {
            Self::SubdomainsOnly(domain.to_string())
        } else {
            Self::Exact(input.to_string())
        }
    }
    /// Whether every host matched by `candidate` is also matched by `self`,
    /// i.e. `candidate` is no broader than this pattern.
    pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
        match self {
            // "*" covers anything.
            DomainPattern::Any => true,
            // An exact pattern only covers the identical host.
            DomainPattern::Exact(domain) => match candidate {
                DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
                _ => false,
            },
            // "*.domain" covers subdomains but never the apex, so candidates
            // that can match the apex (Exact apex, "**.") are only allowed
            // when they sit strictly below `domain`.
            DomainPattern::SubdomainsOnly(domain) => match candidate {
                DomainPattern::Any => false,
                DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
                DomainPattern::SubdomainsOnly(candidate) => {
                    is_subdomain_or_equal(candidate, domain)
                }
                DomainPattern::ApexAndSubdomains(candidate) => {
                    is_strict_subdomain(candidate, domain)
                }
            },
            // "**.domain" covers the apex and everything beneath it.
            DomainPattern::ApexAndSubdomains(domain) => match candidate {
                DomainPattern::Any => false,
                DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
                DomainPattern::SubdomainsOnly(candidate) => {
                    is_subdomain_or_equal(candidate, domain)
                }
                DomainPattern::ApexAndSubdomains(candidate) => {
                    is_subdomain_or_equal(candidate, domain)
                }
            },
        }
    }
}
/// Expand a domain pattern into concrete glob strings.
///
/// `?*.{domain}` requires at least one character before the dot, so the glob
/// matches subdomains but never the apex itself; the "**." form therefore
/// emits both the bare apex and the subdomain glob.
fn expand_domain_pattern(pattern: &str) -> Vec<String> {
    match DomainPattern::parse(pattern) {
        DomainPattern::Any => vec![pattern.to_string()],
        DomainPattern::Exact(domain) => vec![domain],
        DomainPattern::SubdomainsOnly(domain) => {
            vec![format!("?*.{domain}")]
        }
        DomainPattern::ApexAndSubdomains(domain) => {
            vec![domain.clone(), format!("?*.{domain}")]
        }
    }
}
/// Canonical comparison form of a domain: trailing dots removed, lowercased.
fn normalize_domain(domain: &str) -> String {
    domain.trim_end_matches('.').to_ascii_lowercase()
}
/// Case-insensitive domain equality after trailing-dot normalization.
fn domain_eq(left: &str, right: &str) -> bool {
    normalize_domain(left) == normalize_domain(right)
}
/// True when `child` equals `parent` or is a (possibly multi-level) subdomain
/// of it, after normalization.
fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
    let child = normalize_domain(child);
    let parent = normalize_domain(parent);
    child == parent || child.ends_with(&format!(".{parent}"))
}
/// True when `child` is a proper subdomain of `parent` — never the apex itself.
fn is_strict_subdomain(child: &str, parent: &str) -> bool {
    let child = normalize_domain(child);
    let parent = normalize_domain(parent);
    let dotted_suffix = format!(".{parent}");
    child.ends_with(&dotted_suffix) && child != parent
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    // --- method_allowed -----------------------------------------------------
    #[test]
    fn method_allowed_full_allows_everything() {
        assert!(method_allowed(NetworkMode::Full, "GET"));
        assert!(method_allowed(NetworkMode::Full, "POST"));
        assert!(method_allowed(NetworkMode::Full, "CONNECT"));
    }
    #[test]
    fn method_allowed_limited_allows_only_safe_methods() {
        assert!(method_allowed(NetworkMode::Limited, "GET"));
        assert!(method_allowed(NetworkMode::Limited, "HEAD"));
        assert!(method_allowed(NetworkMode::Limited, "OPTIONS"));
        assert!(!method_allowed(NetworkMode::Limited, "POST"));
        assert!(!method_allowed(NetworkMode::Limited, "CONNECT"));
    }
    // --- compile_globset: pattern normalization and wildcard semantics ------
    #[test]
    fn compile_globset_normalizes_trailing_dots() {
        let set = compile_globset(&["Example.COM.".to_string()]).unwrap();
        assert_eq!(true, set.is_match("example.com"));
        // Exact pattern must not match subdomains.
        assert_eq!(false, set.is_match("api.example.com"));
    }
    #[test]
    fn compile_globset_normalizes_wildcards() {
        let set = compile_globset(&["*.Example.COM.".to_string()]).unwrap();
        assert_eq!(true, set.is_match("api.example.com"));
        // "*." matches subdomains only, never the apex.
        assert_eq!(false, set.is_match("example.com"));
    }
    #[test]
    fn compile_globset_normalizes_apex_and_subdomains() {
        let set = compile_globset(&["**.Example.COM.".to_string()]).unwrap();
        assert_eq!(true, set.is_match("example.com"));
        assert_eq!(true, set.is_match("api.example.com"));
    }
    #[test]
    fn compile_globset_normalizes_bracketed_ipv6_literals() {
        let set = compile_globset(&["[::1]".to_string()]).unwrap();
        assert_eq!(true, set.is_match("::1"));
    }
    // --- is_loopback_host ---------------------------------------------------
    #[test]
    fn is_loopback_host_handles_localhost_variants() {
        assert!(is_loopback_host("localhost"));
        assert!(is_loopback_host("localhost."));
        assert!(is_loopback_host("LOCALHOST"));
        assert!(!is_loopback_host("notlocalhost"));
    }
    #[test]
    fn is_loopback_host_handles_ip_literals() {
        assert!(is_loopback_host("127.0.0.1"));
        assert!(is_loopback_host("::1"));
        assert!(!is_loopback_host("1.2.3.4"));
    }
    // --- is_non_public_ip: private/loopback/mapped/v6-local ranges ----------
    #[test]
    fn is_non_public_ip_rejects_private_and_loopback_ranges() {
        assert!(is_non_public_ip("127.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("10.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("192.168.0.1".parse().unwrap()));
        assert!(!is_non_public_ip("8.8.8.8".parse().unwrap()));
        // IPv4-mapped IPv6 addresses must classify like their IPv4 forms.
        assert!(is_non_public_ip("::ffff:127.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("::ffff:10.0.0.1".parse().unwrap()));
        assert!(!is_non_public_ip("::ffff:8.8.8.8".parse().unwrap()));
        assert!(is_non_public_ip("::1".parse().unwrap()));
        assert!(is_non_public_ip("fe80::1".parse().unwrap()));
        assert!(is_non_public_ip("fc00::1".parse().unwrap()));
    }
    // --- normalize_host -----------------------------------------------------
    #[test]
    fn normalize_host_lowercases_and_trims() {
        assert_eq!(normalize_host(" ExAmPlE.CoM "), "example.com");
    }
    #[test]
    fn normalize_host_strips_port_for_host_port() {
        assert_eq!(normalize_host("example.com:1234"), "example.com");
    }
    #[test]
    fn normalize_host_preserves_unbracketed_ipv6() {
        assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
    }
    #[test]
    fn normalize_host_strips_trailing_dot() {
        assert_eq!(normalize_host("example.com."), "example.com");
        assert_eq!(normalize_host("ExAmPlE.CoM."), "example.com");
    }
    #[test]
    fn normalize_host_strips_trailing_dot_with_port() {
        assert_eq!(normalize_host("example.com.:443"), "example.com");
    }
    #[test]
    fn normalize_host_strips_brackets_for_ipv6() {
        assert_eq!(normalize_host("[::1]"), "::1");
        assert_eq!(normalize_host("[::1]:443"), "::1");
    }
}

View File

@@ -0,0 +1,202 @@
use crate::admin;
use crate::config;
use crate::http_proxy;
use crate::init;
use crate::network_policy::NetworkPolicyDecider;
use crate::socks5;
use crate::state::AppState;
use anyhow::Result;
use clap::Parser;
use clap::Subcommand;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tracing::warn;
// CLI arguments for the proxy binary. NOTE: clap turns `///` doc comments into
// help text, so annotations here use plain `//` comments to avoid changing
// the rendered --help output.
#[derive(Debug, Clone, Parser)]
#[command(name = "codex-network-proxy", about = "Codex network sandbox proxy")]
pub struct Args {
    // Optional subcommand; when absent, the proxy itself is expected to run.
    #[command(subcommand)]
    pub command: Option<Command>,
    /// Enable SOCKS5 UDP associate support (default: disabled).
    #[arg(long, default_value_t = false)]
    pub enable_socks5_udp: bool,
}
// Subcommands supported by the CLI (doc comments below double as clap help text).
#[derive(Debug, Clone, Subcommand)]
pub enum Command {
    /// Initialize the Codex network proxy directories (e.g. MITM cert paths).
    Init,
}
/// Builder for [`NetworkProxy`]. Every field is optional; unset fields fall
/// back to config-derived defaults in [`NetworkProxyBuilder::build`].
#[derive(Clone, Default)]
pub struct NetworkProxyBuilder {
    // Shared policy/config state; a fresh AppState is loaded when unset.
    state: Option<Arc<AppState>>,
    // Listener address overrides; defaults come from the resolved runtime config.
    http_addr: Option<SocketAddr>,
    socks_addr: Option<SocketAddr>,
    admin_addr: Option<SocketAddr>,
    // Optional hook consulted for hosts missing from the allowlist.
    policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
    enable_socks5_udp: bool,
}
impl NetworkProxyBuilder {
    /// Use an existing shared [`AppState`] instead of loading config from disk.
    #[must_use]
    pub fn state(mut self, state: Arc<AppState>) -> Self {
        self.state = Some(state);
        self
    }
    /// Override the HTTP proxy listen address.
    #[must_use]
    pub fn http_addr(mut self, addr: SocketAddr) -> Self {
        self.http_addr = Some(addr);
        self
    }
    /// Override the SOCKS5 listen address.
    #[must_use]
    pub fn socks_addr(mut self, addr: SocketAddr) -> Self {
        self.socks_addr = Some(addr);
        self
    }
    /// Override the admin API listen address.
    #[must_use]
    pub fn admin_addr(mut self, addr: SocketAddr) -> Self {
        self.admin_addr = Some(addr);
        self
    }
    /// Install a policy decider by value; it is wrapped in an `Arc` internally.
    #[must_use]
    pub fn policy_decider<D>(mut self, decider: D) -> Self
    where
        D: NetworkPolicyDecider,
    {
        self.policy_decider = Some(Arc::new(decider));
        self
    }
    /// Install an already-shared policy decider.
    #[must_use]
    pub fn policy_decider_arc(mut self, decider: Arc<dyn NetworkPolicyDecider>) -> Self {
        self.policy_decider = Some(decider);
        self
    }
    /// Toggle SOCKS5 UDP associate support (off by default).
    #[must_use]
    pub fn enable_socks5_udp(mut self, enabled: bool) -> Self {
        self.enable_socks5_udp = enabled;
        self
    }
    /// Resolve defaults from the current config and assemble a [`NetworkProxy`].
    ///
    /// # Errors
    /// Fails when the configuration cannot be loaded or read.
    pub async fn build(self) -> Result<NetworkProxy> {
        // Fall back to a freshly loaded AppState when none was injected.
        let state = match self.state {
            Some(state) => state,
            None => Arc::new(AppState::new().await?),
        };
        let runtime = config::resolve_runtime(&state.current_cfg().await?);
        let current_cfg = state.current_cfg().await?;
        // Reapply bind clamping for caller overrides so unix-socket proxying stays loopback-only.
        let (http_addr, admin_addr) = config::clamp_bind_addrs(
            self.http_addr.unwrap_or(runtime.http_addr),
            self.admin_addr.unwrap_or(runtime.admin_addr),
            &current_cfg.network_proxy,
        );
        Ok(NetworkProxy {
            state,
            http_addr,
            socks_addr: self.socks_addr.unwrap_or(runtime.socks_addr),
            admin_addr,
            policy_decider: self.policy_decider,
            enable_socks5_udp: self.enable_socks5_udp,
        })
    }
}
/// Fully resolved proxy configuration, ready to spawn its listeners via
/// [`NetworkProxy::run`]. Construct with [`NetworkProxy::builder`].
#[derive(Clone)]
pub struct NetworkProxy {
    state: Arc<AppState>,
    http_addr: SocketAddr,
    socks_addr: SocketAddr,
    admin_addr: SocketAddr,
    policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
    enable_socks5_udp: bool,
}
impl NetworkProxy {
    /// Entry point for configuring a proxy.
    #[must_use]
    pub fn builder() -> NetworkProxyBuilder {
        NetworkProxyBuilder::default()
    }
    /// Build a proxy from parsed CLI arguments.
    ///
    /// # Errors
    /// Propagates configuration loading failures from `build`.
    pub async fn from_cli_args(args: Args) -> Result<Self> {
        let mut builder = Self::builder();
        builder = builder.enable_socks5_udp(args.enable_socks5_udp);
        builder.build().await
    }
    /// Spawn the HTTP, SOCKS5 and admin listeners as background tasks.
    ///
    /// When `network_proxy.enabled` is false, nothing is spawned and a no-op
    /// handle (whose tasks complete immediately) is returned.
    ///
    /// # Errors
    /// Fails when the current configuration cannot be read.
    pub async fn run(&self) -> Result<NetworkProxyHandle> {
        let current_cfg = self.state.current_cfg().await?;
        if !current_cfg.network_proxy.enabled {
            warn!("network_proxy.enabled is false; skipping proxy listeners");
            return Ok(NetworkProxyHandle::noop());
        }
        // Unix-socket allowlisting is macOS-only; be loud about it at startup
        // on other platforms.
        if cfg!(not(target_os = "macos")) {
            warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
        }
        let http_task = tokio::spawn(http_proxy::run_http_proxy(
            self.state.clone(),
            self.http_addr,
            self.policy_decider.clone(),
        ));
        let socks_task = tokio::spawn(socks5::run_socks5(
            self.state.clone(),
            self.socks_addr,
            self.policy_decider.clone(),
            self.enable_socks5_udp,
        ));
        let admin_task = tokio::spawn(admin::run_admin_api(self.state.clone(), self.admin_addr));
        Ok(NetworkProxyHandle {
            http_task,
            socks_task,
            admin_task,
        })
    }
}
/// Join handles for the three listener tasks spawned by [`NetworkProxy::run`].
pub struct NetworkProxyHandle {
    http_task: JoinHandle<Result<()>>,
    socks_task: JoinHandle<Result<()>>,
    admin_task: JoinHandle<Result<()>>,
}
impl NetworkProxyHandle {
    /// Handle whose tasks all complete immediately; used when the proxy is disabled.
    fn noop() -> Self {
        Self {
            http_task: tokio::spawn(async { Ok(()) }),
            socks_task: tokio::spawn(async { Ok(()) }),
            admin_task: tokio::spawn(async { Ok(()) }),
        }
    }
    /// Wait for the listener tasks to finish, surfacing the first failure.
    ///
    /// Tasks are awaited in order (http, socks, admin), so a later task's
    /// error is only observed after the earlier tasks complete.
    ///
    /// # Errors
    /// Returns join failures (panics/cancellation) or listener errors.
    pub async fn wait(self) -> Result<()> {
        self.http_task.await??;
        self.socks_task.await??;
        self.admin_task.await??;
        Ok(())
    }
    /// Abort all listener tasks and wait for them to wind down.
    ///
    /// Join results are deliberately ignored: aborted tasks report
    /// cancellation, which is expected here.
    pub async fn shutdown(self) -> Result<()> {
        self.http_task.abort();
        self.socks_task.abort();
        self.admin_task.abort();
        let _ = self.http_task.await;
        let _ = self.socks_task.await;
        let _ = self.admin_task.await;
        Ok(())
    }
}
/// Run the `init` subcommand (delegates to the crate's init module).
///
/// # Errors
/// Propagates any error from the underlying initialization routine.
pub fn run_init() -> Result<()> {
    init::run_init()
}

View File

@@ -0,0 +1,54 @@
use rama_http::Body;
use rama_http::Response;
use rama_http::StatusCode;
use serde::Serialize;
/// Build a plain-text response with the given status; falls back to a default
/// 200 response carrying the same body if the builder rejects the parts.
pub fn text_response(status: StatusCode, body: &str) -> Response {
    let payload = body.to_string();
    Response::builder()
        .status(status)
        .header("content-type", "text/plain")
        .body(Body::from(payload))
        .unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
/// Serialize `value` into a 200 JSON response.
///
/// Serialization failures degrade to an empty JSON object (`{}`) rather than
/// surfacing an error to the client.
pub fn json_response<T: Serialize>(value: &T) -> Response {
    let body = serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string());
    Response::builder()
        .status(StatusCode::OK)
        .header("content-type", "application/json")
        .body(Body::from(body))
        .unwrap_or_else(|_| Response::new(Body::from("{}")))
}
/// Map an internal block reason onto the `x-proxy-error` header value.
/// Unknown reasons collapse to the generic "blocked-by-policy".
pub fn blocked_header_value(reason: &str) -> &'static str {
    match reason {
        "denied" => "blocked-by-denylist",
        "method_not_allowed" => "blocked-by-method-policy",
        "mitm_required" => "blocked-by-mitm-required",
        "not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
        _ => "blocked-by-policy",
    }
}
/// Map an internal block reason onto the human-readable response body.
/// Unknown reasons collapse to a generic policy message.
pub fn blocked_message(reason: &str) -> &'static str {
    match reason {
        "denied" => "Codex blocked this request: domain denied by policy.",
        "method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
        "mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
        "not_allowed" => "Codex blocked this request: domain not in allowlist.",
        "not_allowed_local" => "Codex blocked this request: local/private addresses not allowed.",
        _ => "Codex blocked this request by network policy.",
    }
}
/// Build a 403 response for a policy-blocked request, carrying a
/// machine-readable `x-proxy-error` header and a human-readable body.
pub fn blocked_text_response(reason: &str) -> Response {
    let header_value = blocked_header_value(reason);
    let message = blocked_message(reason);
    Response::builder()
        .status(StatusCode::FORBIDDEN)
        .header("content-type", "text/plain")
        .header("x-proxy-error", header_value)
        .body(Body::from(message))
        .unwrap_or_else(|_| Response::new(Body::from("blocked")))
}

View File

@@ -0,0 +1,909 @@
use crate::config::Config;
use crate::config::NetworkMode;
use crate::mitm::MitmState;
use crate::policy::is_loopback_host;
use crate::policy::is_non_public_ip;
use crate::policy::method_allowed;
use crate::policy::normalize_host;
use crate::state::NetworkProxyConstraints;
use crate::state::build_config_state;
use crate::state::validate_policy_against_constraints;
use anyhow::Context;
use anyhow::Result;
use globset::GlobSet;
use serde::Serialize;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::net::IpAddr;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use time::OffsetDateTime;
use tokio::net::lookup_host;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::info;
use tracing::warn;
/// Maximum number of blocked-request events retained in memory; older entries
/// are evicted ring-buffer style (see `record_blocked`).
const MAX_BLOCKED_EVENTS: usize = 200;
/// Time budget for the best-effort DNS lookups performed during policy checks
/// (presumably applied via `timeout` around `lookup_host` — confirm in
/// `host_resolves_to_non_public_ip`).
const DNS_LOOKUP_TIMEOUT: Duration = Duration::from_secs(2);
/// A request the proxy refused, retained in memory for later inspection.
#[derive(Clone, Debug, Serialize)]
pub struct BlockedRequest {
    pub host: String,
    // Machine-readable reason (e.g. "denied", "not_allowed").
    pub reason: String,
    // Client address, when known.
    pub client: Option<String>,
    // HTTP method, when the request came in over HTTP.
    pub method: Option<String>,
    // Network mode in effect at the time, when known.
    pub mode: Option<NetworkMode>,
    // Free-form protocol label supplied by the caller.
    pub protocol: String,
    // Unix timestamp captured at construction (via `unix_timestamp()`).
    pub timestamp: i64,
}
impl BlockedRequest {
    /// Build an event stamped with the current time.
    pub fn new(
        host: String,
        reason: String,
        client: Option<String>,
        method: Option<String>,
        mode: Option<NetworkMode>,
        protocol: String,
    ) -> Self {
        Self {
            host,
            reason,
            client,
            method,
            mode,
            protocol,
            timestamp: unix_timestamp(),
        }
    }
}
/// Everything derived from a single parse of the config file, plus runtime
/// blocked-request history carried across reloads.
#[derive(Clone)]
pub(crate) struct ConfigState {
    pub(crate) config: Config,
    // Config file mtime at load time (presumably consulted by
    // `reload_if_needed` to detect staleness — confirm there).
    pub(crate) mtime: Option<SystemTime>,
    // Precompiled domain globs derived from the allow/deny pattern lists.
    pub(crate) allow_set: GlobSet,
    pub(crate) deny_set: GlobSet,
    // MITM material, present only when MITM is configured.
    pub(crate) mitm: Option<Arc<MitmState>>,
    // Managed-config constraints that runtime edits must not violate
    // (see `set_network_mode`).
    pub(crate) constraints: NetworkProxyConstraints,
    pub(crate) cfg_path: PathBuf,
    // Ring buffer of recent blocked requests, capped at MAX_BLOCKED_EVENTS.
    pub(crate) blocked: VecDeque<BlockedRequest>,
}
/// Shared, reload-on-demand view of the proxy configuration and derived policy.
/// Cheap to clone: all clones share the same inner `RwLock`ed state.
#[derive(Clone)]
pub struct AppState {
    state: Arc<RwLock<ConfigState>>,
}
impl std::fmt::Debug for AppState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Avoid logging internal state (config contents, derived globsets, etc.) which can be noisy
        // and may contain sensitive paths.
        f.debug_struct("AppState").finish_non_exhaustive()
    }
}
impl AppState {
    /// Build the initial state by loading the configuration from disk.
    ///
    /// # Errors
    /// Fails when `build_config_state` cannot produce a state (e.g. config
    /// read/parse problems).
    pub async fn new() -> Result<Self> {
        let cfg_state = build_config_state().await?;
        Ok(Self {
            state: Arc::new(RwLock::new(cfg_state)),
        })
    }
    /// The latest configuration, reloading from disk first when stale.
    pub async fn current_cfg(&self) -> Result<Config> {
        // Callers treat `AppState` as a live view of policy. We reload-on-demand so edits to
        // `config.toml` (including Codex-managed writes) take effect without a restart.
        self.reload_if_needed().await?;
        let guard = self.state.read().await;
        Ok(guard.config.clone())
    }
    /// The configured (uncompiled) allow/deny domain pattern lists.
    pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
        self.reload_if_needed().await?;
        let guard = self.state.read().await;
        Ok((
            guard.config.network_proxy.policy.allowed_domains.clone(),
            guard.config.network_proxy.policy.denied_domains.clone(),
        ))
    }
    /// Whether the proxy is enabled in the current configuration.
    pub async fn enabled(&self) -> Result<bool> {
        self.reload_if_needed().await?;
        let guard = self.state.read().await;
        Ok(guard.config.network_proxy.enabled)
    }
    /// Rebuild state from the config file unconditionally, carrying over the
    /// in-memory blocked-request history. On failure the previous state is
    /// kept and the error is returned.
    pub async fn force_reload(&self) -> Result<()> {
        let mut guard = self.state.write().await;
        let previous_cfg = guard.config.clone();
        // Blocked history is runtime state, not config — preserve it.
        let blocked = guard.blocked.clone();
        match build_config_state().await {
            Ok(mut new_state) => {
                // Policy changes are operationally sensitive; logging diffs makes changes traceable
                // without needing to dump full config blobs (which can include unrelated settings).
                log_policy_changes(&previous_cfg, &new_state.config);
                new_state.blocked = blocked;
                *guard = new_state;
                let path = guard.cfg_path.display();
                info!("reloaded config from {path}");
                Ok(())
            }
            Err(err) => {
                let path = guard.cfg_path.display();
                warn!("failed to reload config from {path}: {err}; keeping previous config");
                Err(err)
            }
        }
    }
    /// Decide whether `host:port` must be blocked by static policy.
    ///
    /// Returns `(blocked, reason)` where `reason` is one of "denied",
    /// "not_allowed_local" or "not_allowed" when blocked, and an empty string
    /// when the request may proceed.
    ///
    /// # Errors
    /// Propagates config reload and DNS classification failures.
    pub async fn host_blocked(&self, host: &str, port: u16) -> Result<(bool, String)> {
        self.reload_if_needed().await?;
        // Snapshot policy under the read lock so it is not held across the
        // DNS await further down.
        let (deny_set, allow_set, allow_local_binding, allowed_domains_empty, allowed_domains) = {
            let guard = self.state.read().await;
            (
                guard.deny_set.clone(),
                guard.allow_set.clone(),
                guard.config.network_proxy.policy.allow_local_binding,
                guard.config.network_proxy.policy.allowed_domains.is_empty(),
                guard.config.network_proxy.policy.allowed_domains.clone(),
            )
        };
        // Decision order matters:
        // 1) explicit deny always wins
        // 2) local/private networking is opt-in (defense-in-depth)
        // 3) allowlist is enforced when configured
        if deny_set.is_match(host) {
            return Ok((true, "denied".to_string()));
        }
        let is_allowlisted = allow_set.is_match(host);
        if !allow_local_binding {
            // If the intent is "prevent access to local/internal networks", we must not rely solely
            // on string checks like `localhost` / `127.0.0.1`. Attackers can use DNS rebinding or
            // public suffix services that map hostnames onto private IPs.
            //
            // We therefore do a best-effort DNS + IP classification check before allowing the
            // request. Explicit local/loopback literals are allowed only when explicitly
            // allowlisted; hostnames that resolve to local/private IPs are blocked even if
            // allowlisted.
            let local_literal = {
                let host = host.trim();
                // Strip any IPv6 zone index ("%eth0") before classifying.
                let host = host.split_once('%').map(|(ip, _)| ip).unwrap_or(host);
                if is_loopback_host(host) {
                    true
                } else if let Ok(ip) = host.parse::<IpAddr>() {
                    is_non_public_ip(ip)
                } else {
                    false
                }
            };
            if local_literal {
                if !is_explicit_local_allowlisted(&allowed_domains, host) {
                    return Ok((true, "not_allowed_local".to_string()));
                }
            } else if host_resolves_to_non_public_ip(host, port).await? {
                return Ok((true, "not_allowed_local".to_string()));
            }
        }
        // An empty allowlist means "allow nothing" rather than "allow all".
        if allowed_domains_empty {
            return Ok((true, "not_allowed".to_string()));
        }
        if !is_allowlisted {
            return Ok((true, "not_allowed".to_string()));
        }
        Ok((false, String::new()))
    }
    /// Append a blocked-request event, evicting the oldest entries so the
    /// buffer never exceeds MAX_BLOCKED_EVENTS.
    pub async fn record_blocked(&self, entry: BlockedRequest) -> Result<()> {
        self.reload_if_needed().await?;
        let mut guard = self.state.write().await;
        guard.blocked.push_back(entry);
        while guard.blocked.len() > MAX_BLOCKED_EVENTS {
            guard.blocked.pop_front();
        }
        Ok(())
    }
    /// Take and clear all accumulated blocked-request events (oldest first).
    pub async fn drain_blocked(&self) -> Result<Vec<BlockedRequest>> {
        self.reload_if_needed().await?;
        let mut guard = self.state.write().await;
        let blocked = std::mem::take(&mut guard.blocked);
        Ok(blocked.into_iter().collect())
    }
    /// Whether `path` is an allowlisted unix-socket target.
    ///
    /// Always false off macOS (the feature is macOS-only) and for relative
    /// paths. Matching is by exact string first, then by best-effort
    /// canonicalized path comparison to tolerate symlinks.
    pub async fn is_unix_socket_allowed(&self, path: &str) -> Result<bool> {
        self.reload_if_needed().await?;
        if cfg!(not(target_os = "macos")) {
            return Ok(false);
        }
        // We only support absolute unix socket paths (a relative path would be ambiguous with
        // respect to the proxy process's CWD and can lead to confusing allowlist behavior).
        if !Path::new(path).is_absolute() {
            return Ok(false);
        }
        let guard = self.state.read().await;
        let requested_canonical = std::fs::canonicalize(path).ok();
        for allowed in &guard.config.network_proxy.policy.allow_unix_sockets {
            if allowed == path {
                return Ok(true);
            }
            // Best-effort canonicalization to reduce surprises with symlinks.
            // If canonicalization fails (e.g., socket not created yet), fall back to raw comparison.
            let Some(requested_canonical) = &requested_canonical else {
                continue;
            };
            if let Ok(allowed_canonical) = std::fs::canonicalize(allowed)
                && &allowed_canonical == requested_canonical
            {
                return Ok(true);
            }
        }
        Ok(false)
    }
/// Whether the HTTP `method` is permitted under the configured network mode.
///
/// Delegates to the free function `method_allowed` with the current mode.
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
    self.reload_if_needed().await?;
    let mode = self.state.read().await.config.network_proxy.mode;
    Ok(method_allowed(mode, method))
}
/// Whether chaining through an upstream proxy is allowed by configuration.
pub async fn allow_upstream_proxy(&self) -> Result<bool> {
    self.reload_if_needed().await?;
    let allowed = self.state.read().await.config.network_proxy.allow_upstream_proxy;
    Ok(allowed)
}
/// The currently configured network mode (e.g. limited vs. full).
pub async fn network_mode(&self) -> Result<NetworkMode> {
    self.reload_if_needed().await?;
    let mode = self.state.read().await.config.network_proxy.mode;
    Ok(mode)
}
/// Update the network mode, rejecting changes that would widen policy
/// beyond what the managed (trusted) configuration allows.
///
/// Validation runs on a scratch copy of the config so live state is only
/// mutated after the candidate passes the constraint check.
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
    self.reload_if_needed().await?;
    let mut guard = self.state.write().await;
    let mut proposed = guard.config.clone();
    proposed.network_proxy.mode = mode;
    validate_policy_against_constraints(&proposed, &guard.constraints)
        .context("network_proxy.mode constrained by managed config")?;
    guard.config.network_proxy.mode = mode;
    info!("updated network mode to {mode:?}");
    Ok(())
}
/// A handle to the MITM state, or `None` when MITM is disabled.
///
/// The `Arc` clone is cheap (refcount bump); callers share the same state.
pub async fn mitm_state(&self) -> Result<Option<Arc<MitmState>>> {
    self.reload_if_needed().await?;
    let mitm = self.state.read().await.mitm.clone();
    Ok(mitm)
}
/// Reload configuration from disk when the config file changed since the
/// last load.
///
/// A single `std::fs::metadata` call both detects existence and yields the
/// mtime. The previous implementation called `exists()` and then
/// `metadata()` separately, which cost an extra syscall and left a TOCTOU
/// window where the file could disappear between the two checks.
async fn reload_if_needed(&self) -> Result<()> {
    let needs_reload = {
        let guard = self.state.read().await;
        match std::fs::metadata(&guard.cfg_path) {
            // If the config file is missing, only reload when it *used to*
            // exist (mtime set). This avoids forcing a reload on every
            // request when running with the default config.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => guard.mtime.is_some(),
            // Transient stat failure (e.g. permissions): keep the current
            // config rather than thrashing on reloads.
            Err(_) => false,
            Ok(metadata) => match (metadata.modified().ok(), guard.mtime) {
                (Some(new_mtime), Some(old_mtime)) => new_mtime > old_mtime,
                (Some(_), None) => true,
                _ => false,
            },
        }
    };
    if !needs_reload {
        return Ok(());
    }
    self.force_reload().await
}
}
/// Best-effort check whether `host` resolves to any local/private address.
///
/// IP literals are classified directly without a DNS round trip. For
/// hostnames, a failed or timed-out lookup is treated as "public": blocking
/// on transient resolver issues would make the proxy fragile, and the
/// subsequent connect attempt will surface the failure anyway. The
/// allowlist/denylist remains the primary control plane.
async fn host_resolves_to_non_public_ip(host: &str, port: u16) -> Result<bool> {
    if let Ok(ip) = host.parse::<IpAddr>() {
        return Ok(is_non_public_ip(ip));
    }
    let Ok(Ok(mut addrs)) = timeout(DNS_LOOKUP_TIMEOUT, lookup_host((host, port))).await else {
        return Ok(false);
    };
    Ok(addrs.any(|addr| is_non_public_ip(addr.ip())))
}
/// Log entries added to / removed from the allow and deny lists between two
/// configuration snapshots.
fn log_policy_changes(previous: &Config, next: &Config) {
    let old_policy = &previous.network_proxy.policy;
    let new_policy = &next.network_proxy.policy;
    log_domain_list_changes(
        "allowlist",
        &old_policy.allowed_domains,
        &new_policy.allowed_domains,
    );
    log_domain_list_changes(
        "denylist",
        &old_policy.denied_domains,
        &new_policy.denied_domains,
    );
}
/// Log each domain that was added to or removed from `list_name`.
///
/// Comparison is case-insensitive (domains are matched case-insensitively
/// elsewhere in the proxy), duplicates within a list are logged only once,
/// and the original casing of each entry is preserved in the log line.
fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]) {
    let previous_set: HashSet<String> = previous
        .iter()
        .map(|entry| entry.to_ascii_lowercase())
        .collect();
    let next_set: HashSet<String> = next
        .iter()
        .map(|entry| entry.to_ascii_lowercase())
        .collect();
    let mut seen_next = HashSet::new();
    for entry in next {
        let key = entry.to_ascii_lowercase();
        // Test membership before inserting: this moves `key` into the seen
        // set only when we might log, avoiding the per-entry clone the
        // previous version paid just to keep the key for a second test.
        if !previous_set.contains(&key) && seen_next.insert(key) {
            info!("config entry added to {list_name}: {entry}");
        }
    }
    let mut seen_previous = HashSet::new();
    for entry in previous {
        let key = entry.to_ascii_lowercase();
        if !next_set.contains(&key) && seen_previous.insert(key) {
            info!("config entry removed from {list_name}: {entry}");
        }
    }
}
/// True when `host` appears in the allowlist as an exact, non-wildcard
/// entry.
///
/// Wildcard patterns intentionally do NOT count: local/loopback
/// destinations must be named explicitly to be allowed when local binding
/// is disabled. The former special-casing of `"*"`, `"*."`, and `"**."`
/// prefixes was redundant — every such pattern contains `'*'` and is
/// already rejected by the metacharacter check below.
fn is_explicit_local_allowlisted(allowed_domains: &[String], host: &str) -> bool {
    let normalized_host = normalize_host(host);
    allowed_domains.iter().any(|pattern| {
        let pattern = pattern.trim();
        // Any glob metacharacter disqualifies the pattern as an explicit entry.
        if pattern.contains('*') || pattern.contains('?') {
            return false;
        }
        normalize_host(pattern) == normalized_host
    })
}
/// Current wall-clock time as whole seconds since the Unix epoch (UTC).
fn unix_timestamp() -> i64 {
    let now = OffsetDateTime::now_utc();
    now.unix_timestamp()
}
#[cfg(test)]
/// Test helper: build an `AppState` with the proxy enabled, full network
/// mode, the given `policy`, and no config file backing it (the nonexistent
/// `cfg_path` means `reload_if_needed` never triggers a reload).
pub(crate) fn app_state_for_policy(policy: crate::config::NetworkPolicy) -> AppState {
    // Compile the glob sets straight from the policy before it is moved
    // into the config below.
    let allow_set = crate::policy::compile_globset(&policy.allowed_domains).unwrap();
    let deny_set = crate::policy::compile_globset(&policy.denied_domains).unwrap();
    let config = Config {
        network_proxy: crate::config::NetworkProxyConfig {
            enabled: true,
            mode: NetworkMode::Full,
            policy,
            ..crate::config::NetworkProxyConfig::default()
        },
    };
    AppState {
        state: Arc::new(RwLock::new(ConfigState {
            config,
            mtime: None,
            allow_set,
            deny_set,
            mitm: None,
            constraints: NetworkProxyConstraints::default(),
            cfg_path: PathBuf::from("/nonexistent/config.toml"),
            blocked: VecDeque::new(),
        })),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::NetworkPolicy;
use crate::config::NetworkProxyConfig;
use crate::policy::compile_globset;
use crate::state::NetworkProxyConstraints;
use crate::state::validate_policy_against_constraints;
use pretty_assertions::assert_eq;
// --- host_blocked: list precedence and allowlist matching -----------------
// Deny entries win even when the same domain is also allowlisted.
#[tokio::test]
async fn host_blocked_denied_wins_over_allowed() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
denied_domains: vec!["example.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("example.com", 80).await.unwrap(),
(true, "denied".to_string())
);
}
// Hosts absent from a non-empty allowlist are blocked with "not_allowed".
#[tokio::test]
async fn host_blocked_requires_allowlist_match() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("example.com", 80).await.unwrap(),
(false, String::new())
);
assert_eq!(
// Use a public IP literal to avoid relying on ambient DNS behavior (some networks
// resolve unknown hostnames to private IPs, which would trigger `not_allowed_local`).
state.host_blocked("8.8.8.8", 80).await.unwrap(),
(true, "not_allowed".to_string())
);
}
// `*.domain` matches subdomains only; the apex needs `**.domain` or an exact entry.
#[tokio::test]
async fn host_blocked_subdomain_wildcards_exclude_apex() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*.openai.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("api.openai.com", 80).await.unwrap(),
(false, String::new())
);
assert_eq!(
state.host_blocked("openai.com", 80).await.unwrap(),
(true, "not_allowed".to_string())
);
}
// --- host_blocked: local-binding enforcement ------------------------------
// With allow_local_binding=false, loopback targets are rejected even when
// some (non-matching) allowlist exists.
#[tokio::test]
async fn host_blocked_rejects_loopback_when_local_binding_disabled() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
assert_eq!(
state.host_blocked("localhost", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
// A bare "*" wildcard must NOT implicitly allow loopback targets.
#[tokio::test]
async fn host_blocked_rejects_loopback_when_allowlist_is_wildcard() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
// Nor must "*" implicitly allow RFC1918 literals.
#[tokio::test]
async fn host_blocked_rejects_private_ip_literal_when_allowlist_is_wildcard() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("10.0.0.1", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
// An exact (non-wildcard) allowlist entry is the only way to reach a
// local target while local binding is disabled.
#[tokio::test]
async fn host_blocked_allows_loopback_when_explicitly_allowlisted_and_local_binding_disabled() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["localhost".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("localhost", 80).await.unwrap(),
(false, String::new())
);
}
#[tokio::test]
async fn host_blocked_allows_private_ip_literal_when_explicitly_allowlisted() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["10.0.0.1".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("10.0.0.1", 80).await.unwrap(),
(false, String::new())
);
}
// Scoped IPv6 literals (with a %zone suffix) are classified by the IP part.
#[tokio::test]
async fn host_blocked_rejects_scoped_ipv6_literal_when_not_allowlisted() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
#[tokio::test]
async fn host_blocked_allows_scoped_ipv6_literal_when_explicitly_allowlisted() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["fe80::1%lo0".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
(false, String::new())
);
}
#[tokio::test]
async fn host_blocked_rejects_private_ip_literals_when_local_binding_disabled() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("10.0.0.1", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
// The local check fires before the empty-allowlist check, so the reason is
// "not_allowed_local" rather than "not_allowed".
#[tokio::test]
async fn host_blocked_rejects_loopback_when_allowlist_empty() {
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec![],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1", 80).await.unwrap(),
(true, "not_allowed_local".to_string())
);
}
// --- validate_policy_against_constraints: managed-config enforcement ------
#[test]
fn validate_policy_against_constraints_disallows_widening_allowed_domains() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["example.com".to_string(), "evil.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_widening_mode() {
let constraints = NetworkProxyConstraints {
mode: Some(NetworkMode::Limited),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
mode: NetworkMode::Full,
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
// A user allowlist entry inside a managed wildcard's scope is a narrowing
// and therefore permitted.
#[test]
fn validate_policy_against_constraints_allows_narrowing_wildcard_allowlist() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["*.example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["api.example.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
}
// `**.example.com` also matches the apex, so it widens `*.example.com`.
#[test]
fn validate_policy_against_constraints_rejects_widening_wildcard_allowlist() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["*.example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["**.example.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
// Managed deny entries may not be dropped by user layers.
#[test]
fn validate_policy_against_constraints_requires_managed_denied_domains_entries() {
let constraints = NetworkProxyConstraints {
denied_domains: Some(vec!["evil.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
denied_domains: vec![],
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_enabling_when_managed_disabled() {
let constraints = NetworkProxyConstraints {
enabled: Some(false),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_allow_local_binding_when_managed_disabled() {
let constraints = NetworkProxyConstraints {
allow_local_binding: Some(false),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
policy: NetworkPolicy {
allow_local_binding: true,
..NetworkPolicy::default()
},
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
// The "dangerously_" admin opt-in must come from a managed layer.
#[test]
fn validate_policy_against_constraints_disallows_non_loopback_admin_without_managed_opt_in() {
let constraints = NetworkProxyConstraints {
dangerously_allow_non_loopback_admin: Some(false),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
dangerously_allow_non_loopback_admin: true,
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_allows_non_loopback_admin_with_managed_opt_in() {
let constraints = NetworkProxyConstraints {
dangerously_allow_non_loopback_admin: Some(true),
..NetworkProxyConstraints::default()
};
let config = Config {
network_proxy: NetworkProxyConfig {
enabled: true,
dangerously_allow_non_loopback_admin: true,
..NetworkProxyConfig::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
}
// --- compile_globset: domain pattern compilation semantics ----------------
#[test]
fn compile_globset_is_case_insensitive() {
let patterns = vec!["ExAmPle.CoM".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("example.com"));
assert!(set.is_match("EXAMPLE.COM"));
}
#[test]
fn compile_globset_excludes_apex_for_subdomain_patterns() {
let patterns = vec!["*.openai.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("api.openai.com"));
assert!(!set.is_match("openai.com"));
assert!(!set.is_match("evilopenai.com"));
}
#[test]
fn compile_globset_includes_apex_for_double_wildcard_patterns() {
let patterns = vec!["**.openai.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("openai.com"));
assert!(set.is_match("api.openai.com"));
assert!(!set.is_match("evilopenai.com"));
}
#[test]
fn compile_globset_matches_all_with_star() {
let patterns = vec!["*".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("openai.com"));
assert!(set.is_match("api.openai.com"));
}
#[test]
fn compile_globset_dedupes_patterns_without_changing_behavior() {
let patterns = vec!["example.com".to_string(), "example.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("example.com"));
assert!(set.is_match("EXAMPLE.COM"));
assert!(!set.is_match("not-example.com"));
}
#[test]
fn compile_globset_rejects_invalid_patterns() {
let patterns = vec!["[".to_string()];
assert!(compile_globset(&patterns).is_err());
}
// --- unix socket allowlist (macOS-only feature) ---------------------------
#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_is_respected_on_macos() {
let socket_path = "/tmp/example.sock".to_string();
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![socket_path.clone()],
..NetworkPolicy::default()
});
assert!(state.is_unix_socket_allowed(&socket_path).await.unwrap());
assert!(
!state
.is_unix_socket_allowed("/tmp/not-allowed.sock")
.await
.unwrap()
);
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_resolves_symlinks() {
use std::os::unix::fs::symlink;
let unique = OffsetDateTime::now_utc().unix_timestamp_nanos();
let dir = std::env::temp_dir().join(format!("codex-network-proxy-test-{unique}"));
std::fs::create_dir_all(&dir).unwrap();
let real = dir.join("real.sock");
let link = dir.join("link.sock");
// The allowlist mechanism is path-based; for test purposes we don't need an actual unix
// domain socket. Any filesystem entry works for canonicalization.
std::fs::write(&real, b"not a socket").unwrap();
symlink(&real, &link).unwrap();
let real_s = real.to_str().unwrap().to_string();
let link_s = link.to_str().unwrap().to_string();
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![real_s],
..NetworkPolicy::default()
});
assert!(state.is_unix_socket_allowed(&link_s).await.unwrap());
let _ = std::fs::remove_file(&link);
let _ = std::fs::remove_file(&real);
let _ = std::fs::remove_dir(&dir);
}
#[cfg(not(target_os = "macos"))]
#[tokio::test]
async fn unix_socket_allowlist_is_rejected_on_non_macos() {
let socket_path = "/tmp/example.sock".to_string();
let state = app_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![socket_path.clone()],
..NetworkPolicy::default()
});
assert!(!state.is_unix_socket_allowed(&socket_path).await.unwrap());
}
}

View File

@@ -0,0 +1,302 @@
use crate::config::NetworkMode;
use crate::network_policy::NetworkDecision;
use crate::network_policy::NetworkPolicyDecider;
use crate::network_policy::NetworkPolicyRequest;
use crate::network_policy::NetworkProtocol;
use crate::network_policy::evaluate_host_policy;
use crate::policy::normalize_host;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Context as _;
use anyhow::Result;
use rama_core::Layer;
use rama_core::Service;
use rama_core::extensions::ExtensionsRef;
use rama_core::layer::AddInputExtensionLayer;
use rama_core::service::service_fn;
use rama_net::stream::SocketInfo;
use rama_socks5::Socks5Acceptor;
use rama_socks5::server::DefaultConnector;
use rama_socks5::server::DefaultUdpRelay;
use rama_socks5::server::udp::RelayRequest;
use rama_socks5::server::udp::RelayResponse;
use rama_tcp::client::Request as TcpRequest;
use rama_tcp::client::service::TcpConnector;
use rama_tcp::server::TcpListener;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
use tracing::warn;
/// Run a SOCKS5 proxy listener on `addr` until the server is shut down.
///
/// Every TCP CONNECT is funneled through a policy-checking connector before
/// the upstream connection is made; when `enable_socks5_udp` is set, UDP
/// ASSOCIATE traffic is additionally inspected per-datagram. Checks applied
/// to each request, in order: proxy enabled, network mode (SOCKS5 requires
/// mode=full), then host policy via `evaluate_host_policy` (allow/deny
/// lists plus the optional external `policy_decider`). Blocked requests are
/// recorded via `record_blocked` on a best-effort basis (errors ignored).
///
/// Returns an error only if binding the listener fails; `serve` itself
/// runs until cancelled.
pub async fn run_socks5(
state: Arc<AppState>,
addr: SocketAddr,
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
enable_socks5_udp: bool,
) -> Result<()> {
let listener = TcpListener::build()
.bind(addr)
.await
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
.map_err(rama_core::error::OpaqueError::from)
.map_err(anyhow::Error::from)
.with_context(|| format!("bind SOCKS5 proxy: {addr}"))?;
info!("SOCKS5 proxy listening on {addr}");
// Startup-time hint only: the mode is re-checked per request below, so a
// later config change takes effect without a restart.
match state.network_mode().await {
Ok(NetworkMode::Limited) => {
info!("SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5");
}
Ok(NetworkMode::Full) => {}
Err(err) => {
warn!("failed to read network mode: {err}");
}
}
let tcp_connector = TcpConnector::default();
// Wrap the plain TCP connector with per-request policy enforcement. The
// closure runs once per CONNECT attempt, before any upstream connection.
let policy_tcp_connector = service_fn({
let policy_decider = policy_decider.clone();
move |req: TcpRequest| {
let tcp_connector = tcp_connector.clone();
let policy_decider = policy_decider.clone();
async move {
// AppState is injected via AddInputExtensionLayer below; its
// absence indicates a wiring bug, not a client error.
let app_state = req
.extensions()
.get::<Arc<AppState>>()
.cloned()
.ok_or_else(|| io::Error::other("missing state"))?;
let host = normalize_host(&req.authority.host.to_string());
let port = req.authority.port;
// Client address, used only for logging / blocked-event records.
let client = req
.extensions()
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
// Check 1: proxy globally enabled.
match app_state.enabled().await {
Ok(true) => {}
Ok(false) => {
// Best-effort record; a failure to record must not mask the block.
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"proxy_disabled".to_string(),
client.clone(),
None,
None,
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("SOCKS blocked; proxy disabled (client={client}, host={host})");
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"proxy disabled",
)
.into());
}
Err(err) => {
error!("failed to read enabled state: {err}");
return Err(io::Error::other("proxy error").into());
}
}
// Check 2: SOCKS5 TCP is only permitted in full mode.
match app_state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
client.clone(),
None,
Some(NetworkMode::Limited),
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!("failed to evaluate method policy: {err}");
return Err(io::Error::other("proxy error").into());
}
}
// Check 3: host-level allow/deny policy (plus external decider, if any).
let request = NetworkPolicyRequest::new(
NetworkProtocol::Socks5Tcp,
host.clone(),
port,
client.clone(),
None,
None,
None,
);
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
Ok(NetworkDecision::Deny { reason }) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
None,
None,
"socks5".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
return Err(
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
);
}
Ok(NetworkDecision::Allow) => {
let client = client.as_deref().unwrap_or_default();
info!("SOCKS allowed (client={client}, host={host}, port={port})");
}
Err(err) => {
error!("failed to evaluate host: {err}");
return Err(io::Error::other("proxy error").into());
}
}
// All checks passed: hand off to the real TCP connector.
tcp_connector.serve(req).await
}
}
});
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
let base = Socks5Acceptor::new().with_connector(socks_connector);
if enable_socks5_udp {
let udp_state = state.clone();
let udp_decider = policy_decider.clone();
// The UDP relay inspector mirrors the TCP checks but, instead of failing
// the connection, drops the datagram (maybe_payload: None) on a block —
// SOCKS5 UDP has no per-datagram error channel back to the client.
let udp_relay = DefaultUdpRelay::default().with_async_inspector(service_fn(
move |request: RelayRequest| {
let udp_state = udp_state.clone();
let udp_decider = udp_decider.clone();
async move {
let RelayRequest {
server_address,
payload,
extensions,
..
} = request;
let host = normalize_host(&server_address.ip_addr.to_string());
let port = server_address.port;
let client = extensions
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string());
match udp_state.enabled().await {
Ok(true) => {}
Ok(false) => {
let _ = udp_state
.record_blocked(BlockedRequest::new(
host.clone(),
"proxy_disabled".to_string(),
client.clone(),
None,
None,
"socks5-udp".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
"SOCKS UDP blocked; proxy disabled (client={client}, host={host})"
);
// Drop the datagram rather than erroring the relay.
return Ok(RelayResponse {
maybe_payload: None,
extensions,
});
}
Err(err) => {
error!("failed to read enabled state: {err}");
return Err(io::Error::other("proxy error"));
}
}
match udp_state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = udp_state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
client.clone(),
None,
Some(NetworkMode::Limited),
"socks5-udp".to_string(),
))
.await;
return Ok(RelayResponse {
maybe_payload: None,
extensions,
});
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!("failed to evaluate method policy: {err}");
return Err(io::Error::other("proxy error"));
}
}
let request = NetworkPolicyRequest::new(
NetworkProtocol::Socks5Udp,
host.clone(),
port,
client.clone(),
None,
None,
None,
);
match evaluate_host_policy(&udp_state, udp_decider.as_ref(), &request).await {
Ok(NetworkDecision::Deny { reason }) => {
let _ = udp_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
None,
None,
"socks5-udp".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!(
"SOCKS UDP blocked (client={client}, host={host}, reason={reason})"
);
Ok(RelayResponse {
maybe_payload: None,
extensions,
})
}
Ok(NetworkDecision::Allow) => Ok(RelayResponse {
maybe_payload: Some(payload),
extensions,
}),
Err(err) => {
error!("failed to evaluate UDP host: {err}");
Err(io::Error::other("proxy error"))
}
}
}
},
));
let socks_acceptor = base.with_udp_associator(udp_relay);
// AddInputExtensionLayer makes Arc<AppState> available to the connector
// closures via request extensions.
listener
.serve(AddInputExtensionLayer::new(state).into_layer(socks_acceptor))
.await;
} else {
listener
.serve(AddInputExtensionLayer::new(state).into_layer(base))
.await;
}
Ok(())
}

View File

@@ -0,0 +1,425 @@
use crate::config::Config;
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use crate::mitm::MitmState;
use crate::policy::DomainPattern;
use crate::policy::compile_globset;
use crate::runtime::ConfigState;
use anyhow::Context;
use anyhow::Result;
use codex_app_server_protocol::ConfigLayerSource;
use codex_core::config::CONFIG_TOML_FILE;
use codex_core::config::ConfigBuilder;
use codex_core::config::Constrained;
use codex_core::config::ConstraintError;
use codex_core::config_loader::RequirementSource;
use serde::Deserialize;
use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
pub use crate::runtime::AppState;
pub use crate::runtime::BlockedRequest;
#[cfg(test)]
pub(crate) use crate::runtime::app_state_for_policy;
/// Build the runtime `ConfigState` from the merged Codex configuration.
///
/// Loading goes through `codex-core` so we inherit the same layer ordering
/// and semantics as the rest of Codex (system/managed layers, user layers,
/// session flags, etc.), and we deserialize from the merged effective
/// config rather than parsing config.toml ourselves — avoiding a second
/// parser/merger implementation and the drift that comes with it.
pub(crate) async fn build_config_state() -> Result<ConfigState> {
    let codex_cfg = ConfigBuilder::default()
        .build()
        .await
        .context("failed to load Codex config")?;
    let mut config: Config = codex_cfg
        .config_layer_stack
        .effective_config()
        .try_into()
        .context("failed to deserialize network proxy config")?;
    // Security boundary: user-controlled layers must not be able to widen
    // restrictions set by trusted/managed layers (e.g., MDM). Enforce this
    // before building runtime state.
    let constraints = enforce_trusted_constraints(&codex_cfg.config_layer_stack, &config)?;
    // Permit relative MITM paths for ergonomics; resolve them relative to
    // CODEX_HOME so the proxy can be configured from multiple config
    // locations without changing cert paths.
    resolve_mitm_paths(&mut config, &codex_cfg.codex_home);
    let cfg_path = codex_cfg.codex_home.join(CONFIG_TOML_FILE);
    // Remember the on-disk mtime so reloads only happen on real changes.
    let mtime = cfg_path.metadata().and_then(|m| m.modified()).ok();
    let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains)?;
    let deny_set = compile_globset(&config.network_proxy.policy.denied_domains)?;
    let mitm = if config.network_proxy.mitm.enabled {
        build_mitm_state(
            &config.network_proxy.mitm,
            config.network_proxy.allow_upstream_proxy,
        )?
    } else {
        None
    };
    Ok(ConfigState {
        config,
        mtime,
        allow_set,
        deny_set,
        mitm,
        constraints,
        cfg_path,
        blocked: std::collections::VecDeque::new(),
    })
}
/// Resolve relative MITM certificate/key paths against `codex_home`.
///
/// Absolute paths are left untouched; only relative ones are re-rooted so
/// the same config works regardless of the proxy's working directory.
fn resolve_mitm_paths(config: &mut Config, codex_home: &Path) {
    let mitm = &mut config.network_proxy.mitm;
    if mitm.ca_cert_path.is_relative() {
        mitm.ca_cert_path = codex_home.join(&mitm.ca_cert_path);
    }
    if mitm.ca_key_path.is_relative() {
        mitm.ca_key_path = codex_home.join(&mitm.ca_key_path);
    }
}
/// Construct the shared MITM state from its config section.
///
/// Always returns `Some` on success; the caller decides whether to call
/// this at all based on `mitm.enabled`.
fn build_mitm_state(
    config: &MitmConfig,
    allow_upstream_proxy: bool,
) -> Result<Option<Arc<MitmState>>> {
    let state = MitmState::new(config, allow_upstream_proxy)?;
    Ok(Some(Arc::new(state)))
}
/// Minimal deserialization view of a single config layer: only the
/// `network_proxy` table is read, with every field optional so an
/// "explicitly set by this layer" value can be told apart from an absent
/// one. Field names are the wire format and must not be renamed.
#[derive(Debug, Default, Deserialize)]
struct PartialConfig {
#[serde(default)]
network_proxy: PartialNetworkProxyConfig,
}
/// Optional-valued mirror of the `[network_proxy]` table used when scanning
/// trusted layers for constraints. `None` means the layer did not set the
/// field; `Some` contributes a constraint (see
/// `network_proxy_constraints_from_trusted_layers`).
#[derive(Debug, Default, Deserialize)]
struct PartialNetworkProxyConfig {
enabled: Option<bool>,
mode: Option<NetworkMode>,
allow_upstream_proxy: Option<bool>,
dangerously_allow_non_loopback_proxy: Option<bool>,
dangerously_allow_non_loopback_admin: Option<bool>,
#[serde(default)]
policy: PartialNetworkProxyConfig2,
}
/// Optional-valued mirror of `[network_proxy.policy]` for constraint
/// scanning; each `Some` value from a trusted layer becomes a constraint.
#[derive(Debug, Default, Deserialize)]
struct PartialNetworkPolicy {
#[serde(default)]
allowed_domains: Option<Vec<String>>,
#[serde(default)]
denied_domains: Option<Vec<String>>,
#[serde(default)]
allow_unix_sockets: Option<Vec<String>>,
#[serde(default)]
allow_local_binding: Option<bool>,
}
/// Limits extracted from trusted/managed config layers.
///
/// `None` means "unconstrained". `Some` values cap what user-controlled
/// layers may set: user config can narrow within these bounds but never
/// widen past them (enforced by `validate_policy_against_constraints`).
#[derive(Debug, Default, Clone)]
pub(crate) struct NetworkProxyConstraints {
pub(crate) enabled: Option<bool>,
pub(crate) mode: Option<NetworkMode>,
pub(crate) allow_upstream_proxy: Option<bool>,
pub(crate) dangerously_allow_non_loopback_proxy: Option<bool>,
pub(crate) dangerously_allow_non_loopback_admin: Option<bool>,
pub(crate) allowed_domains: Option<Vec<String>>,
pub(crate) denied_domains: Option<Vec<String>>,
pub(crate) allow_unix_sockets: Option<Vec<String>>,
pub(crate) allow_local_binding: Option<bool>,
}
/// Extract constraints from trusted layers and verify the effective config
/// stays within them, returning the constraints for later re-validation
/// (e.g. when runtime setters propose a config change).
fn enforce_trusted_constraints(
    layers: &codex_core::config_loader::ConfigLayerStack,
    config: &Config,
) -> Result<NetworkProxyConstraints> {
    let constraints = network_proxy_constraints_from_trusted_layers(layers)?;
    validate_policy_against_constraints(config, &constraints)
        .context("network proxy constraints")?;
    Ok(constraints)
}
/// Collect network-proxy constraints from all trusted (non-user) config
/// layers.
///
/// Layers are visited lowest-precedence first, so a higher-precedence
/// trusted layer's explicit setting overwrites a lower one's. Only fields a
/// layer explicitly set (`Some`) contribute; absent fields leave the
/// accumulated constraint untouched.
fn network_proxy_constraints_from_trusted_layers(
    layers: &codex_core::config_loader::ConfigLayerStack,
) -> Result<NetworkProxyConstraints> {
    // Replace the slot only when the layer actually set the field.
    fn overwrite<T>(slot: &mut Option<T>, value: Option<T>) {
        if value.is_some() {
            *slot = value;
        }
    }
    let mut constraints = NetworkProxyConstraints::default();
    for layer in layers
        .get_layers(codex_core::config_loader::ConfigLayerStackOrdering::LowestPrecedenceFirst)
    {
        // Only trusted layers contribute constraints. User-controlled layers
        // can narrow policy but must never widen beyond what managed config
        // allows.
        if is_user_controlled_layer(&layer.name) {
            continue;
        }
        let partial: PartialConfig = layer
            .config
            .clone()
            .try_into()
            .context("failed to deserialize trusted config layer")?;
        let proxy = partial.network_proxy;
        overwrite(&mut constraints.enabled, proxy.enabled);
        overwrite(&mut constraints.mode, proxy.mode);
        overwrite(&mut constraints.allow_upstream_proxy, proxy.allow_upstream_proxy);
        overwrite(
            &mut constraints.dangerously_allow_non_loopback_proxy,
            proxy.dangerously_allow_non_loopback_proxy,
        );
        overwrite(
            &mut constraints.dangerously_allow_non_loopback_admin,
            proxy.dangerously_allow_non_loopback_admin,
        );
        overwrite(&mut constraints.allowed_domains, proxy.policy.allowed_domains);
        overwrite(&mut constraints.denied_domains, proxy.policy.denied_domains);
        overwrite(&mut constraints.allow_unix_sockets, proxy.policy.allow_unix_sockets);
        overwrite(&mut constraints.allow_local_binding, proxy.policy.allow_local_binding);
    }
    Ok(constraints)
}
fn is_user_controlled_layer(layer: &ConfigLayerSource) -> bool {
matches!(
layer,
ConfigLayerSource::User { .. }
| ConfigLayerSource::Project { .. }
| ConfigLayerSource::SessionFlags
)
}
/// Verify that the effective config does not exceed the ceilings imposed by
/// trusted (managed) config layers.
///
/// Each `Some(..)` field in `constraints` acts as an upper bound: the user
/// config may be equal or more restrictive, never more permissive. Returns
/// the first violation found as a [`ConstraintError`].
pub(crate) fn validate_policy_against_constraints(
    config: &Config,
    constraints: &NetworkProxyConstraints,
) -> std::result::Result<(), ConstraintError> {
    // Shared error constructor. By convention `candidate` holds the
    // offending configured value and `allowed` describes the requirement;
    // the requirement source is left Unknown here.
    fn invalid_value(
        field_name: &'static str,
        candidate: impl Into<String>,
        allowed: impl Into<String>,
    ) -> ConstraintError {
        ConstraintError::InvalidValue {
            field_name,
            candidate: candidate.into(),
            allowed: allowed.into(),
            requirement_source: RequirementSource::Unknown,
        }
    }
    // enabled: managed `false` forbids enabling the proxy.
    let enabled = config.network_proxy.enabled;
    if let Some(max_enabled) = constraints.enabled {
        let _ = Constrained::new(enabled, move |candidate| {
            if *candidate && !max_enabled {
                Err(invalid_value(
                    "network_proxy.enabled",
                    "true",
                    "false (disabled by managed config)",
                ))
            } else {
                Ok(())
            }
        })?;
    }
    // mode: the configured mode may not be more permissive than the
    // managed ceiling (see `network_mode_rank`).
    if let Some(max_mode) = constraints.mode {
        let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
            if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
                Err(invalid_value(
                    "network_proxy.mode",
                    format!("{candidate:?}"),
                    format!("{max_mode:?} or more restrictive"),
                ))
            } else {
                Ok(())
            }
        })?;
    }
    // allow_upstream_proxy: managed `false` forbids it; `true` or unset
    // leaves the user's choice alone.
    let allow_upstream_proxy = constraints.allow_upstream_proxy;
    let _ = Constrained::new(
        config.network_proxy.allow_upstream_proxy,
        move |candidate| match allow_upstream_proxy {
            Some(true) | None => Ok(()),
            Some(false) => {
                if *candidate {
                    Err(invalid_value(
                        "network_proxy.allow_upstream_proxy",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            }
        },
    )?;
    // dangerously_allow_non_loopback_admin: same false-forbids pattern.
    let allow_non_loopback_admin = constraints.dangerously_allow_non_loopback_admin;
    let _ = Constrained::new(
        config.network_proxy.dangerously_allow_non_loopback_admin,
        move |candidate| match allow_non_loopback_admin {
            Some(true) | None => Ok(()),
            Some(false) => {
                if *candidate {
                    Err(invalid_value(
                        "network_proxy.dangerously_allow_non_loopback_admin",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            }
        },
    )?;
    // dangerously_allow_non_loopback_proxy: same false-forbids pattern.
    let allow_non_loopback_proxy = constraints.dangerously_allow_non_loopback_proxy;
    let _ = Constrained::new(
        config.network_proxy.dangerously_allow_non_loopback_proxy,
        move |candidate| match allow_non_loopback_proxy {
            Some(true) | None => Ok(()),
            Some(false) => {
                if *candidate {
                    Err(invalid_value(
                        "network_proxy.dangerously_allow_non_loopback_proxy",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            }
        },
    )?;
    // policy.allow_local_binding: managed `false` forbids it.
    if let Some(allow_local_binding) = constraints.allow_local_binding {
        let _ = Constrained::new(
            config.network_proxy.policy.allow_local_binding,
            move |candidate| {
                if *candidate && !allow_local_binding {
                    Err(invalid_value(
                        "network_proxy.policy.allow_local_binding",
                        "true",
                        "false (disabled by managed config)",
                    ))
                } else {
                    Ok(())
                }
            },
        )?;
    }
    // policy.allowed_domains: every configured pattern must be covered by
    // some managed pattern (a subset relation over domain patterns).
    if let Some(allowed_domains) = &constraints.allowed_domains {
        let managed_patterns: Vec<DomainPattern> = allowed_domains
            .iter()
            .map(|entry| DomainPattern::parse(entry))
            .collect();
        let _ = Constrained::new(
            config.network_proxy.policy.allowed_domains.clone(),
            move |candidate| {
                let mut invalid = Vec::new();
                for entry in candidate {
                    let candidate_pattern = DomainPattern::parse(entry);
                    if !managed_patterns
                        .iter()
                        .any(|managed| managed.allows(&candidate_pattern))
                    {
                        invalid.push(entry.clone());
                    }
                }
                if invalid.is_empty() {
                    Ok(())
                } else {
                    Err(invalid_value(
                        "network_proxy.policy.allowed_domains",
                        format!("{invalid:?}"),
                        "subset of managed allowed_domains",
                    ))
                }
            },
        )?;
    }
    // policy.denied_domains: managed entries are mandatory — the configured
    // list must contain every managed entry (case-insensitive).
    if let Some(denied_domains) = &constraints.denied_domains {
        let required_set: HashSet<String> = denied_domains
            .iter()
            .map(|s| s.to_ascii_lowercase())
            .collect();
        let _ = Constrained::new(
            config.network_proxy.policy.denied_domains.clone(),
            move |candidate| {
                let candidate_set: HashSet<String> =
                    candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
                let missing: Vec<String> = required_set
                    .iter()
                    .filter(|entry| !candidate_set.contains(*entry))
                    .cloned()
                    .collect();
                if missing.is_empty() {
                    Ok(())
                } else {
                    // Fix: `candidate` carries the offending data and
                    // `allowed` the requirement, matching every other call
                    // site (the arguments were previously inverted).
                    Err(invalid_value(
                        "network_proxy.policy.denied_domains",
                        format!("missing entries: {missing:?}"),
                        "must include all managed denied_domains entries",
                    ))
                }
            },
        )?;
    }
    // policy.allow_unix_sockets: every configured socket path must appear
    // in the managed allow list (case-insensitive).
    if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
        let allowed_set: HashSet<String> = allow_unix_sockets
            .iter()
            .map(|s| s.to_ascii_lowercase())
            .collect();
        let _ = Constrained::new(
            config.network_proxy.policy.allow_unix_sockets.clone(),
            move |candidate| {
                let mut invalid = Vec::new();
                for entry in candidate {
                    if !allowed_set.contains(&entry.to_ascii_lowercase()) {
                        invalid.push(entry.clone());
                    }
                }
                if invalid.is_empty() {
                    Ok(())
                } else {
                    Err(invalid_value(
                        "network_proxy.policy.allow_unix_sockets",
                        format!("{invalid:?}"),
                        "subset of managed allow_unix_sockets",
                    ))
                }
            },
        )?;
    }
    Ok(())
}
/// Relative permissiveness of a network mode; a higher rank grants more
/// network access. Used to compare a configured mode against a managed
/// ceiling.
fn network_mode_rank(mode: NetworkMode) -> u8 {
    match mode {
        NetworkMode::Full => 1,
        NetworkMode::Limited => 0,
    }
}

View File

@@ -0,0 +1,188 @@
use rama_core::Layer;
use rama_core::Service;
use rama_core::error::BoxError;
use rama_core::error::ErrorContext as _;
use rama_core::error::OpaqueError;
use rama_core::extensions::ExtensionsMut;
use rama_core::extensions::ExtensionsRef;
use rama_core::service::BoxService;
use rama_http::Body;
use rama_http::Request;
use rama_http::Response;
use rama_http::layer::version_adapter::RequestVersionAdapter;
use rama_http_backend::client::HttpClientService;
use rama_http_backend::client::HttpConnector;
use rama_http_backend::client::proxy::layer::HttpProxyConnectorLayer;
use rama_net::address::ProxyAddress;
use rama_net::client::EstablishedClientConnection;
use rama_net::http::RequestContext;
use rama_tcp::client::service::TcpConnector;
use rama_tls_boring::client::TlsConnectorDataBuilder;
use rama_tls_boring::client::TlsConnectorLayer;
use tracing::warn;
#[cfg(target_os = "macos")]
use rama_unix::client::UnixConnector;
/// Upstream proxy addresses read from the conventional environment
/// variables; each field is `None` when no usable value was found.
#[derive(Clone, Default)]
struct ProxyConfig {
    // Proxy for plain-http requests (HTTP_PROXY / http_proxy).
    http: Option<ProxyAddress>,
    // Proxy for secure requests (HTTPS_PROXY / https_proxy).
    https: Option<ProxyAddress>,
    // Fallback proxy for either kind (ALL_PROXY / all_proxy).
    all: Option<ProxyAddress>,
}
impl ProxyConfig {
    /// Build the config from the process environment; for each kind the
    /// uppercase variable name is consulted before the lowercase one.
    fn from_env() -> Self {
        Self {
            http: read_proxy_env(&["HTTP_PROXY", "http_proxy"]),
            https: read_proxy_env(&["HTTPS_PROXY", "https_proxy"]),
            all: read_proxy_env(&["ALL_PROXY", "all_proxy"]),
        }
    }

    /// Pick the proxy for `req` based on whether its protocol is secure.
    /// A request whose context cannot be derived is treated as insecure.
    fn proxy_for_request(&self, req: &Request) -> Option<ProxyAddress> {
        let secure = matches!(
            RequestContext::try_from(req),
            Ok(ctx) if ctx.protocol.is_secure()
        );
        self.proxy_for_protocol(secure)
    }

    /// Resolve the proxy for a secure/insecure connection. Secure traffic
    /// prefers `https`, then `http`, then `all`; insecure traffic prefers
    /// `http`, then `all`.
    fn proxy_for_protocol(&self, is_secure: bool) -> Option<ProxyAddress> {
        if is_secure {
            if let Some(proxy) = &self.https {
                return Some(proxy.clone());
            }
        }
        if let Some(proxy) = &self.http {
            return Some(proxy.clone());
        }
        self.all.clone()
    }
}
/// Return the first usable HTTP proxy address found under `keys`, checked
/// in order. Unset or blank variables are skipped silently; values that
/// fail to parse, or that name a non-http proxy protocol, are skipped with
/// a warning. A proxy with no explicit protocol is accepted as http.
fn read_proxy_env(keys: &[&str]) -> Option<ProxyAddress> {
    keys.iter().find_map(|key| {
        let raw = std::env::var(key).ok()?;
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            return None;
        }
        match ProxyAddress::try_from(trimmed) {
            Ok(proxy) => {
                let is_http = proxy
                    .protocol
                    .as_ref()
                    .map(rama_net::Protocol::is_http)
                    .unwrap_or(true);
                if is_http {
                    Some(proxy)
                } else {
                    warn!("ignoring {key}: non-http proxy protocol");
                    None
                }
            }
            Err(err) => {
                warn!("ignoring {key}: invalid proxy address ({err})");
                None
            }
        }
    })
}
/// Proxy to use for CONNECT tunnels, resolved from the environment as if
/// the tunneled traffic were secure.
pub(crate) fn proxy_for_connect() -> Option<ProxyAddress> {
    let env_config = ProxyConfig::from_env();
    env_config.proxy_for_protocol(true)
}
/// HTTP client used to forward proxied requests to their upstream origin.
#[derive(Clone)]
pub(crate) struct UpstreamClient {
    // Boxed connector stack that establishes the transport and yields an
    // HTTP service over which the request is dispatched.
    connector: BoxService<
        Request<Body>,
        EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
        BoxError,
    >,
    // Upstream proxy selection; all-`None` (the default) means direct.
    proxy_config: ProxyConfig,
}
impl UpstreamClient {
    /// Client that always connects directly, ignoring proxy env vars.
    pub(crate) fn direct() -> Self {
        Self::new(ProxyConfig::default())
    }

    /// Client honoring `HTTP_PROXY`/`HTTPS_PROXY`/`ALL_PROXY` from the
    /// environment.
    pub(crate) fn from_env_proxy() -> Self {
        Self::new(ProxyConfig::from_env())
    }

    /// Client that sends every request over the unix socket at `path`;
    /// no upstream proxy is applied in this mode.
    #[cfg(target_os = "macos")]
    pub(crate) fn unix_socket(path: &str) -> Self {
        Self {
            connector: build_unix_connector(path),
            proxy_config: ProxyConfig::default(),
        }
    }

    fn new(proxy_config: ProxyConfig) -> Self {
        Self {
            connector: build_http_connector(),
            proxy_config,
        }
    }
}
impl Service<Request<Body>> for UpstreamClient {
    type Output = Response;
    type Error = OpaqueError;

    /// Forward `req` upstream: attach a proxy address extension when one is
    /// configured for the request's protocol, establish the connection via
    /// the connector stack, then dispatch the request over it.
    async fn serve(&self, mut req: Request<Body>) -> Result<Self::Output, Self::Error> {
        if let Some(proxy_addr) = self.proxy_config.proxy_for_request(&req) {
            req.extensions_mut().insert(proxy_addr);
        }

        // Keep the URI for error context; `req` is consumed below.
        let original_uri = req.uri().clone();

        let established = self
            .connector
            .serve(req)
            .await
            .map_err(OpaqueError::from_boxed)?;
        let EstablishedClientConnection {
            input: mut upstream_req,
            conn: http_connection,
        } = established;

        // Propagate connection-level extensions onto the request before
        // dispatching it.
        upstream_req
            .extensions_mut()
            .extend(http_connection.extensions().clone());

        http_connection
            .serve(upstream_req)
            .await
            .map_err(OpaqueError::from_boxed)
            .with_context(|| format!("http request failure for uri: {original_uri}"))
    }
}
/// Build the standard outbound connector stack: TCP transport, optional
/// upstream HTTP proxy, TLS, then HTTP-version adaptation.
fn build_http_connector() -> BoxService<
    Request<Body>,
    EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
    BoxError,
> {
    let tls_data = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();

    // Layer order matters: each layer wraps the one below it.
    let with_proxy = HttpProxyConnectorLayer::optional().into_layer(TcpConnector::default());
    let with_tls = TlsConnectorLayer::auto()
        .with_connector_data(tls_data)
        .into_layer(with_proxy);
    let adapted = RequestVersionAdapter::new(with_tls);

    HttpConnector::new(adapted).boxed()
}
/// Build a connector that speaks HTTP over a fixed unix-domain socket.
#[cfg(target_os = "macos")]
fn build_unix_connector(
    path: &str,
) -> BoxService<
    Request<Body>,
    EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
    BoxError,
> {
    HttpConnector::new(UnixConnector::fixed(path)).boxed()
}