mirror of
https://github.com/openai/codex.git
synced 2026-02-03 15:33:41 +00:00
Compare commits
67 Commits
codex-work
...
pr-network
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3788afd19d | ||
|
|
5738cf125c | ||
|
|
6cb436a4e2 | ||
|
|
4196294c26 | ||
|
|
d54757bbbd | ||
|
|
fc33c31258 | ||
|
|
4a0c292de3 | ||
|
|
1dd695246f | ||
|
|
872d0ae3db | ||
|
|
5d7f98a5a8 | ||
|
|
58562a2a43 | ||
|
|
4995f09c47 | ||
|
|
e3d19064be | ||
|
|
e4c003d108 | ||
|
|
bcdedf5211 | ||
|
|
f1cc7fbae8 | ||
|
|
90c24700ac | ||
|
|
741b661cfa | ||
|
|
8637043a0c | ||
|
|
fe1c1c859f | ||
|
|
c8b7c0091a | ||
|
|
57c971470d | ||
|
|
5d6611170b | ||
|
|
3d1e12b49e | ||
|
|
7f44c725fb | ||
|
|
e8cff7ef77 | ||
|
|
d85717dcf8 | ||
|
|
c656278537 | ||
|
|
8338bebc32 | ||
|
|
302e6eea7d | ||
|
|
bd0ff89517 | ||
|
|
cbb5f48ba3 | ||
|
|
6c1df8b73e | ||
|
|
be94fb6913 | ||
|
|
8f6413ce5d | ||
|
|
a61ab56b77 | ||
|
|
258b7ecdbd | ||
|
|
b49b83847d | ||
|
|
ab28660a52 | ||
|
|
e6194d5c89 | ||
|
|
0bbe48c03e | ||
|
|
1906a23fa1 | ||
|
|
0dd709317f | ||
|
|
74d748cefb | ||
|
|
981c7c3261 | ||
|
|
826e40683e | ||
|
|
a60515bc85 | ||
|
|
6ef1dd9917 | ||
|
|
ef2c2d3131 | ||
|
|
d2042b92b6 | ||
|
|
310c79eef5 | ||
|
|
ee102bcb63 | ||
|
|
3e9046128a | ||
|
|
e60d43c3a7 | ||
|
|
4f3097b585 | ||
|
|
9b2a353e6e | ||
|
|
10abb38b53 | ||
|
|
2d7980340d | ||
|
|
6f4edec9f1 | ||
|
|
fc35891b07 | ||
|
|
dc063ff890 | ||
|
|
9d473922e3 | ||
|
|
127b89b4ed | ||
|
|
9b20af68f0 | ||
|
|
83e8a702fb | ||
|
|
eceb76bf3d | ||
|
|
f65edf9c91 |
61
.github/scripts/install-musl-build-tools.sh
vendored
Normal file
61
.github/scripts/install-musl-build-tools.sh
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
: "${TARGET:?TARGET environment variable is required}"
|
||||
: "${GITHUB_ENV:?GITHUB_ENV environment variable is required}"
|
||||
|
||||
apt_update_args=()
|
||||
if [[ -n "${APT_UPDATE_ARGS:-}" ]]; then
|
||||
# shellcheck disable=SC2206
|
||||
apt_update_args=(${APT_UPDATE_ARGS})
|
||||
fi
|
||||
|
||||
apt_install_args=()
|
||||
if [[ -n "${APT_INSTALL_ARGS:-}" ]]; then
|
||||
# shellcheck disable=SC2206
|
||||
apt_install_args=(${APT_INSTALL_ARGS})
|
||||
fi
|
||||
|
||||
sudo apt-get update "${apt_update_args[@]}"
|
||||
sudo apt-get install -y "${apt_install_args[@]}" musl-tools pkg-config g++ clang libc++-dev libc++abi-dev lld
|
||||
|
||||
case "${TARGET}" in
|
||||
x86_64-unknown-linux-musl)
|
||||
arch="x86_64"
|
||||
;;
|
||||
aarch64-unknown-linux-musl)
|
||||
arch="aarch64"
|
||||
;;
|
||||
*)
|
||||
echo "Unexpected musl target: ${TARGET}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
if command -v clang++ >/dev/null; then
|
||||
cxx="$(command -v clang++)"
|
||||
echo "CXXFLAGS=--target=${TARGET} -stdlib=libc++ -pthread" >> "$GITHUB_ENV"
|
||||
echo "CFLAGS=--target=${TARGET} -pthread" >> "$GITHUB_ENV"
|
||||
if command -v clang >/dev/null; then
|
||||
cc="$(command -v clang)"
|
||||
echo "CC=${cc}" >> "$GITHUB_ENV"
|
||||
echo "TARGET_CC=${cc}" >> "$GITHUB_ENV"
|
||||
target_cc_var="CC_${TARGET}"
|
||||
target_cc_var="${target_cc_var//-/_}"
|
||||
echo "${target_cc_var}=${cc}" >> "$GITHUB_ENV"
|
||||
fi
|
||||
elif command -v "${arch}-linux-musl-g++" >/dev/null; then
|
||||
cxx="$(command -v "${arch}-linux-musl-g++")"
|
||||
elif command -v musl-g++ >/dev/null; then
|
||||
cxx="$(command -v musl-g++)"
|
||||
elif command -v musl-gcc >/dev/null; then
|
||||
cxx="$(command -v musl-gcc)"
|
||||
echo "CFLAGS=-pthread" >> "$GITHUB_ENV"
|
||||
else
|
||||
echo "musl g++ not found after install; arch=${arch}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "CXX=${cxx}" >> "$GITHUB_ENV"
|
||||
echo "CMAKE_CXX_COMPILER=${cxx}" >> "$GITHUB_ENV"
|
||||
echo "CMAKE_ARGS=-DCMAKE_HAVE_THREADS_LIBRARY=1 -DCMAKE_USE_PTHREADS_INIT=1 -DCMAKE_THREAD_LIBS_INIT=-pthread -DTHREADS_PREFER_PTHREAD_FLAG=ON" >> "$GITHUB_ENV"
|
||||
8
.github/workflows/rust-ci.yml
vendored
8
.github/workflows/rust-ci.yml
vendored
@@ -265,11 +265,11 @@ jobs:
|
||||
name: Install musl build tools
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
TARGET: ${{ matrix.target }}
|
||||
APT_UPDATE_ARGS: -o Acquire::Retries=3
|
||||
APT_INSTALL_ARGS: --no-install-recommends
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
sudo apt-get -y update -o Acquire::Retries=3
|
||||
sudo apt-get -y install --no-install-recommends musl-tools pkg-config
|
||||
run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
|
||||
|
||||
- name: Install cargo-chef
|
||||
if: ${{ matrix.profile == 'release' }}
|
||||
|
||||
6
.github/workflows/rust-release.yml
vendored
6
.github/workflows/rust-release.yml
vendored
@@ -96,9 +96,9 @@ jobs:
|
||||
|
||||
- if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}}
|
||||
name: Install musl build tools
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y musl-tools pkg-config
|
||||
env:
|
||||
TARGET: ${{ matrix.target }}
|
||||
run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
|
||||
|
||||
- name: Cargo build
|
||||
shell: bash
|
||||
|
||||
6
.github/workflows/shell-tool-mcp.yml
vendored
6
.github/workflows/shell-tool-mcp.yml
vendored
@@ -99,9 +99,9 @@ jobs:
|
||||
|
||||
- if: ${{ matrix.install_musl }}
|
||||
name: Install musl build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y musl-tools pkg-config
|
||||
env:
|
||||
TARGET: ${{ matrix.target }}
|
||||
run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
|
||||
|
||||
- name: Build exec server binaries
|
||||
run: cargo build --release --target ${{ matrix.target }} --bin codex-exec-mcp-server --bin codex-execve-wrapper
|
||||
|
||||
1001
codex-rs/Cargo.lock
generated
1001
codex-rs/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -27,6 +27,7 @@ members = [
|
||||
"login",
|
||||
"mcp-server",
|
||||
"mcp-types",
|
||||
"network-proxy",
|
||||
"ollama",
|
||||
"process-hardening",
|
||||
"protocol",
|
||||
@@ -138,6 +139,7 @@ env-flags = "0.1.1"
|
||||
env_logger = "0.11.5"
|
||||
eventsource-stream = "0.2.3"
|
||||
futures = { version = "0.3", default-features = false }
|
||||
globset = "0.4"
|
||||
http = "1.3.1"
|
||||
icu_decimal = "2.1"
|
||||
icu_locale_core = "2.1"
|
||||
|
||||
@@ -73,6 +73,7 @@ ignore = [
|
||||
{ id = "RUSTSEC-2024-0388", reason = "derivative is unmaintained; pulled in via starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
|
||||
{ id = "RUSTSEC-2025-0057", reason = "fxhash is unmaintained; pulled in via starlark_map/starlark v0.13.0 used by execpolicy/cli/core; no fixed release yet" },
|
||||
{ id = "RUSTSEC-2024-0436", reason = "paste is unmaintained; pulled in via ratatui/rmcp/starlark used by tui/execpolicy; no fixed release yet" },
|
||||
{ id = "RUSTSEC-2025-0134", reason = "rustls-pemfile is unmaintained; pulled in via rama-tls-rustls used by codex-network-proxy; no safe upgrade until rama removes the dependency" },
|
||||
# TODO(joshka, nornagon): remove this exception when once we update the ratatui fork to a version that uses lru 0.13+.
|
||||
{ id = "RUSTSEC-2026-0002", reason = "lru 0.12.5 is pulled in via ratatui fork; cannot upgrade until the fork is updated" },
|
||||
]
|
||||
|
||||
45
codex-rs/network-proxy/Cargo.toml
Normal file
45
codex-rs/network-proxy/Cargo.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
[package]
|
||||
name = "codex-network-proxy"
|
||||
edition = "2024"
|
||||
version = { workspace = true }
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "codex-network-proxy"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
name = "codex_network_proxy"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
globset = { workspace = true }
|
||||
rcgen-rama = { package = "rcgen", version = "0.14", default-features = false, features = ["pem", "x509-parser", "ring"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["fmt"] }
|
||||
rama-core = { version = "=0.3.0-alpha.4" }
|
||||
rama-http = { version = "=0.3.0-alpha.4" }
|
||||
rama-http-backend = { version = "=0.3.0-alpha.4", features = ["tls"] }
|
||||
rama-net = { version = "=0.3.0-alpha.4", features = ["http", "tls"] }
|
||||
rama-socks5 = { version = "=0.3.0-alpha.4" }
|
||||
rama-tcp = { version = "=0.3.0-alpha.4", features = ["http"] }
|
||||
rama-tls-boring = { version = "=0.3.0-alpha.4", features = ["http"] }
|
||||
rama-utils = { version = "=0.3.0-alpha.4" }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
[target.'cfg(target_family = "unix")'.dependencies]
|
||||
rama-unix = { version = "=0.3.0-alpha.4" }
|
||||
202
codex-rs/network-proxy/README.md
Normal file
202
codex-rs/network-proxy/README.md
Normal file
@@ -0,0 +1,202 @@
|
||||
# codex-network-proxy
|
||||
|
||||
`codex-network-proxy` is Codex's local network policy enforcement proxy. It runs:
|
||||
|
||||
- an HTTP proxy (default `127.0.0.1:3128`)
|
||||
- a SOCKS5 proxy (default `127.0.0.1:8081`)
|
||||
- an admin HTTP API (default `127.0.0.1:8080`)
|
||||
|
||||
It enforces an allow/deny policy and a "limited" mode intended for read-only network access.
|
||||
|
||||
## Quickstart
|
||||
|
||||
### 1) Configure
|
||||
|
||||
`codex-network-proxy` reads from Codex's merged `config.toml` (via `codex-core` config loading).
|
||||
|
||||
Example config:
|
||||
|
||||
```toml
|
||||
[network_proxy]
|
||||
enabled = true
|
||||
proxy_url = "http://127.0.0.1:3128"
|
||||
admin_url = "http://127.0.0.1:8080"
|
||||
# SOCKS5 listens on 127.0.0.1:8081 by default. Override via `NetworkProxyBuilder::socks_addr`.
|
||||
# When `enabled` is false, the proxy no-ops and does not bind listeners.
|
||||
# When true, respect HTTP(S)_PROXY/ALL_PROXY for upstream requests (HTTP(S) proxies only),
|
||||
# including CONNECT tunnels in full mode.
|
||||
allow_upstream_proxy = false
|
||||
# By default, non-loopback binds are clamped to loopback for safety.
|
||||
# If you want to expose these listeners beyond localhost, you must opt in explicitly.
|
||||
dangerously_allow_non_loopback_proxy = false
|
||||
dangerously_allow_non_loopback_admin = false
|
||||
mode = "limited" # or "full"
|
||||
|
||||
[network_proxy.mitm]
|
||||
# When enabled, HTTPS CONNECT can be terminated so limited-mode method policy still applies.
|
||||
# CA cert/key paths are relative to CODEX_HOME by default.
|
||||
enabled = false
|
||||
ca_cert_path = "network_proxy/mitm/ca.pem"
|
||||
ca_key_path = "network_proxy/mitm/ca.key"
|
||||
# Maximum size of request/response bodies MITM will buffer for inspection.
|
||||
max_body_bytes = 1048576
|
||||
|
||||
[network_proxy.policy]
|
||||
# Hosts must match the allowlist (unless denied).
|
||||
# If `allowed_domains` is empty, the proxy blocks requests until an allowlist is configured.
|
||||
allowed_domains = ["*.openai.com"]
|
||||
denied_domains = ["evil.example"]
|
||||
|
||||
# If false, local/private networking is rejected. Explicit allowlisting of local IP literals
|
||||
# (or `localhost`) is required to permit them.
|
||||
# Hostnames that resolve to local/private IPs are still blocked even if allowlisted.
|
||||
allow_local_binding = false
|
||||
|
||||
# macOS-only: allows proxying to a unix socket when request includes `x-unix-socket: /path`.
|
||||
allow_unix_sockets = ["/tmp/example.sock"]
|
||||
```
|
||||
|
||||
### 2) Run the proxy
|
||||
|
||||
```bash
|
||||
cargo run -p codex-network-proxy --
|
||||
```
|
||||
|
||||
If you plan to enable MITM, initialize the default directory first:
|
||||
|
||||
```bash
|
||||
cargo run -p codex-network-proxy -- init
|
||||
```
|
||||
|
||||
The proxy will generate a local CA on first MITM use if the files do not exist. Import the
|
||||
generated CA cert into your system trust store to avoid TLS errors.
|
||||
|
||||
### 3) Point a client at it
|
||||
|
||||
For HTTP(S) traffic:
|
||||
|
||||
```bash
|
||||
export HTTP_PROXY="http://127.0.0.1:3128"
|
||||
export HTTPS_PROXY="http://127.0.0.1:3128"
|
||||
```
|
||||
|
||||
For SOCKS5 traffic:
|
||||
|
||||
```bash
|
||||
export ALL_PROXY="socks5h://127.0.0.1:8081"
|
||||
```
|
||||
|
||||
To enable SOCKS5 UDP associate support:
|
||||
|
||||
```bash
|
||||
cargo run -p codex-network-proxy -- --enable-socks5-udp
|
||||
```
|
||||
|
||||
### 4) Understand blocks / debugging
|
||||
|
||||
When a request is blocked, the proxy responds with `403` and includes:
|
||||
|
||||
- `x-proxy-error`: one of:
|
||||
- `blocked-by-allowlist`
|
||||
- `blocked-by-denylist`
|
||||
- `blocked-by-method-policy`
|
||||
- `blocked-by-mitm-required`
|
||||
- `blocked-by-policy`
|
||||
|
||||
In "limited" mode, only `GET`, `HEAD`, and `OPTIONS` are allowed. HTTPS CONNECT requests require
|
||||
MITM to enforce limited-mode method policy; otherwise they are blocked.
|
||||
|
||||
## Library API
|
||||
|
||||
`codex-network-proxy` can be embedded as a library with a thin API:
|
||||
|
||||
```rust
|
||||
use codex_network_proxy::{NetworkProxy, NetworkDecision, NetworkPolicyRequest};
|
||||
|
||||
let proxy = NetworkProxy::builder()
|
||||
.http_addr("127.0.0.1:8080".parse()?)
|
||||
.admin_addr("127.0.0.1:9000".parse()?)
|
||||
.policy_decider(|request: NetworkPolicyRequest| async move {
|
||||
// Example: auto-allow when exec policy already approved a command prefix.
|
||||
if let Some(command) = request.command.as_deref() {
|
||||
if command.starts_with("curl ") {
|
||||
return NetworkDecision::Allow;
|
||||
}
|
||||
}
|
||||
NetworkDecision::Deny {
|
||||
reason: "policy_denied".to_string(),
|
||||
}
|
||||
})
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
let handle = proxy.run().await?;
|
||||
handle.shutdown().await?;
|
||||
```
|
||||
|
||||
When unix socket proxying is enabled, HTTP/admin bind overrides are still clamped to loopback
|
||||
to avoid turning the proxy into a remote bridge to local daemons.
|
||||
|
||||
### Policy hook (exec-policy mapping)
|
||||
|
||||
The proxy exposes a policy hook (`NetworkPolicyDecider`) that can override allowlist-only blocks.
|
||||
It receives `command` and `exec_policy_hint` fields when supplied by the embedding app. This lets
|
||||
core map exec approvals to network access, e.g. if a user already approved `curl *` for a session,
|
||||
the decider can auto-allow network requests originating from that command.
|
||||
|
||||
**Important:** Explicit deny rules still win. The decider only gets a chance to override
|
||||
`not_allowed` (allowlist misses), not `denied` or `not_allowed_local`.
|
||||
|
||||
## Admin API
|
||||
|
||||
The admin API is a small HTTP server intended for debugging and runtime adjustments.
|
||||
|
||||
Endpoints:
|
||||
|
||||
```bash
|
||||
curl -sS http://127.0.0.1:8080/health
|
||||
curl -sS http://127.0.0.1:8080/config
|
||||
curl -sS http://127.0.0.1:8080/patterns
|
||||
curl -sS http://127.0.0.1:8080/blocked
|
||||
|
||||
# Switch modes without restarting:
|
||||
curl -sS -X POST http://127.0.0.1:8080/mode -d '{"mode":"full"}'
|
||||
|
||||
# Force a config reload:
|
||||
curl -sS -X POST http://127.0.0.1:8080/reload
|
||||
```
|
||||
|
||||
## Platform notes
|
||||
|
||||
- Unix socket proxying via the `x-unix-socket` header is **macOS-only**; other platforms will
|
||||
reject unix socket requests.
|
||||
- HTTPS tunneling uses BoringSSL via Rama's `rama-tls-boring`; building the proxy requires a
|
||||
native toolchain and CMake on macOS/Linux/Windows.
|
||||
|
||||
## Security notes (important)
|
||||
|
||||
This section documents the protections implemented by `codex-network-proxy`, and the boundaries of
|
||||
what it can reasonably guarantee.
|
||||
|
||||
- Allowlist-first policy: if `allowed_domains` is empty, requests are blocked until an allowlist is configured.
|
||||
- Deny wins: entries in `denied_domains` always override the allowlist.
|
||||
- Local/private network protection: when `allow_local_binding = false`, the proxy blocks loopback
|
||||
and common private/link-local ranges. Explicit allowlisting of local IP literals (or `localhost`)
|
||||
is required to permit them; hostnames that resolve to local/private IPs are still blocked even if
|
||||
allowlisted (best-effort DNS lookup).
|
||||
- Limited mode enforcement:
|
||||
- only `GET`, `HEAD`, and `OPTIONS` are allowed
|
||||
- HTTPS CONNECT is blocked unless MITM is enabled
|
||||
- Listener safety defaults:
|
||||
- the admin API is unauthenticated; non-loopback binds are clamped unless explicitly enabled via
|
||||
`dangerously_allow_non_loopback_admin`
|
||||
- the HTTP proxy listener similarly clamps non-loopback binds unless explicitly enabled via
|
||||
`dangerously_allow_non_loopback_proxy`
|
||||
- when unix socket proxying is enabled, both listeners are forced to loopback to avoid turning the
|
||||
proxy into a remote bridge into local daemons.
|
||||
- `enabled` is enforced at runtime; when false the proxy no-ops and does not bind listeners.
|
||||
Limitations:
|
||||
|
||||
- DNS rebinding is hard to fully prevent without pinning the resolved IP(s) all the way down to the
|
||||
transport layer. If your threat model includes hostile DNS, enforce network egress at a lower
|
||||
layer too (e.g., firewall / VPC / corporate proxy policies).
|
||||
157
codex-rs/network-proxy/src/admin.rs
Normal file
157
codex-rs/network-proxy/src/admin.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::responses::json_response;
|
||||
use crate::responses::text_response;
|
||||
use crate::state::AppState;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use rama_core::rt::Executor;
|
||||
use rama_core::service::service_fn;
|
||||
use rama_http::Body;
|
||||
use rama_http::Request;
|
||||
use rama_http::Response;
|
||||
use rama_http::StatusCode;
|
||||
use rama_http_backend::server::HttpServer;
|
||||
use rama_tcp::server::TcpListener;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
|
||||
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
|
||||
// Debug-only admin API (health/config/patterns/blocked + mode/reload). Policy is config-driven
|
||||
// and constraint-enforced; this endpoint should not become a second policy/approval plane.
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
|
||||
.map_err(rama_core::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind admin API: {addr}"))?;
|
||||
|
||||
let server_state = state.clone();
|
||||
let server = HttpServer::auto(Executor::new()).service(service_fn(move |req| {
|
||||
let state = server_state.clone();
|
||||
async move { handle_admin_request(state, req).await }
|
||||
}));
|
||||
info!("admin API listening on {addr}");
|
||||
listener.serve(server).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_admin_request(state: Arc<AppState>, req: Request) -> Result<Response, Infallible> {
|
||||
const MODE_BODY_LIMIT: usize = 8 * 1024;
|
||||
|
||||
let method = req.method().clone();
|
||||
let path = req.uri().path().to_string();
|
||||
let response = match (method.as_str(), path.as_str()) {
|
||||
("GET", "/health") => Response::new(Body::from("ok")),
|
||||
("GET", "/config") => match state.current_cfg().await {
|
||||
Ok(cfg) => json_response(&cfg),
|
||||
Err(err) => {
|
||||
error!("failed to load config: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
|
||||
}
|
||||
},
|
||||
("GET", "/patterns") => match state.current_patterns().await {
|
||||
Ok((allow, deny)) => json_response(&PatternsResponse {
|
||||
allowed: allow,
|
||||
denied: deny,
|
||||
}),
|
||||
Err(err) => {
|
||||
error!("failed to load patterns: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
|
||||
}
|
||||
},
|
||||
("GET", "/blocked") => match state.drain_blocked().await {
|
||||
Ok(blocked) => json_response(&BlockedResponse { blocked }),
|
||||
Err(err) => {
|
||||
error!("failed to read blocked queue: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
|
||||
}
|
||||
},
|
||||
("POST", "/mode") => {
|
||||
let mut body = req.into_body();
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
loop {
|
||||
let chunk = match body.chunk().await {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
error!("failed to read mode body: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
|
||||
}
|
||||
};
|
||||
let Some(chunk) = chunk else {
|
||||
break;
|
||||
};
|
||||
|
||||
if buf.len().saturating_add(chunk.len()) > MODE_BODY_LIMIT {
|
||||
return Ok(text_response(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"body too large",
|
||||
));
|
||||
}
|
||||
buf.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if buf.is_empty() {
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
|
||||
}
|
||||
let update: ModeUpdate = match serde_json::from_slice(&buf) {
|
||||
Ok(update) => update,
|
||||
Err(err) => {
|
||||
error!("failed to parse mode update: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
|
||||
}
|
||||
};
|
||||
match state.set_network_mode(update.mode).await {
|
||||
Ok(()) => json_response(&ModeUpdateResponse {
|
||||
status: "ok",
|
||||
mode: update.mode,
|
||||
}),
|
||||
Err(err) => {
|
||||
error!("mode update failed: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
("POST", "/reload") => match state.force_reload().await {
|
||||
Ok(()) => json_response(&ReloadResponse { status: "reloaded" }),
|
||||
Err(err) => {
|
||||
error!("reload failed: {err}");
|
||||
text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
|
||||
}
|
||||
},
|
||||
_ => text_response(StatusCode::NOT_FOUND, "not found"),
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModeUpdate {
|
||||
mode: NetworkMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct PatternsResponse {
|
||||
allowed: Vec<String>,
|
||||
denied: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BlockedResponse<T> {
|
||||
blocked: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ModeUpdateResponse {
|
||||
status: &'static str,
|
||||
mode: NetworkMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ReloadResponse {
|
||||
status: &'static str,
|
||||
}
|
||||
417
codex-rs/network-proxy/src/config.rs
Normal file
417
codex-rs/network-proxy/src/config.rs
Normal file
@@ -0,0 +1,417 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::net::IpAddr;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Config {
|
||||
#[serde(default)]
|
||||
pub network_proxy: NetworkProxyConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NetworkProxyConfig {
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
#[serde(default = "default_proxy_url")]
|
||||
pub proxy_url: String,
|
||||
#[serde(default = "default_admin_url")]
|
||||
pub admin_url: String,
|
||||
#[serde(default)]
|
||||
pub allow_upstream_proxy: bool,
|
||||
#[serde(default)]
|
||||
pub dangerously_allow_non_loopback_proxy: bool,
|
||||
#[serde(default)]
|
||||
pub dangerously_allow_non_loopback_admin: bool,
|
||||
#[serde(default)]
|
||||
pub mode: NetworkMode,
|
||||
#[serde(default)]
|
||||
pub policy: NetworkPolicy,
|
||||
#[serde(default)]
|
||||
pub mitm: MitmConfig,
|
||||
}
|
||||
|
||||
impl Default for NetworkProxyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
proxy_url: default_proxy_url(),
|
||||
admin_url: default_admin_url(),
|
||||
allow_upstream_proxy: false,
|
||||
dangerously_allow_non_loopback_proxy: false,
|
||||
dangerously_allow_non_loopback_admin: false,
|
||||
mode: NetworkMode::default(),
|
||||
policy: NetworkPolicy::default(),
|
||||
mitm: MitmConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct NetworkPolicy {
|
||||
#[serde(default)]
|
||||
pub allowed_domains: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub denied_domains: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub allow_unix_sockets: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub allow_local_binding: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NetworkMode {
|
||||
Limited,
|
||||
#[default]
|
||||
Full,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MitmConfig {
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
#[serde(default)]
|
||||
pub inspect: bool,
|
||||
#[serde(default = "default_mitm_max_body_bytes")]
|
||||
pub max_body_bytes: usize,
|
||||
#[serde(default = "default_ca_cert_path")]
|
||||
pub ca_cert_path: PathBuf,
|
||||
#[serde(default = "default_ca_key_path")]
|
||||
pub ca_key_path: PathBuf,
|
||||
}
|
||||
|
||||
impl Default for MitmConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
inspect: false,
|
||||
max_body_bytes: default_mitm_max_body_bytes(),
|
||||
ca_cert_path: default_ca_cert_path(),
|
||||
ca_key_path: default_ca_key_path(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_proxy_url() -> String {
|
||||
"http://127.0.0.1:3128".to_string()
|
||||
}
|
||||
|
||||
fn default_admin_url() -> String {
|
||||
"http://127.0.0.1:8080".to_string()
|
||||
}
|
||||
|
||||
fn default_ca_cert_path() -> PathBuf {
|
||||
PathBuf::from("network_proxy/mitm/ca.pem")
|
||||
}
|
||||
|
||||
fn default_ca_key_path() -> PathBuf {
|
||||
PathBuf::from("network_proxy/mitm/ca.key")
|
||||
}
|
||||
|
||||
fn default_mitm_max_body_bytes() -> usize {
|
||||
4096
|
||||
}
|
||||
|
||||
fn clamp_non_loopback(addr: SocketAddr, allow_non_loopback: bool, name: &str) -> SocketAddr {
|
||||
if addr.ip().is_loopback() {
|
||||
return addr;
|
||||
}
|
||||
|
||||
if allow_non_loopback {
|
||||
warn!("DANGEROUS: {name} listening on non-loopback address {addr}");
|
||||
return addr;
|
||||
}
|
||||
|
||||
warn!(
|
||||
"{name} requested non-loopback bind ({addr}); clamping to 127.0.0.1:{port} (set dangerously_allow_non_loopback_proxy or dangerously_allow_non_loopback_admin to override)",
|
||||
port = addr.port()
|
||||
);
|
||||
SocketAddr::from(([127, 0, 0, 1], addr.port()))
|
||||
}
|
||||
|
||||
pub(crate) fn clamp_bind_addrs(
|
||||
http_addr: SocketAddr,
|
||||
admin_addr: SocketAddr,
|
||||
cfg: &NetworkProxyConfig,
|
||||
) -> (SocketAddr, SocketAddr) {
|
||||
let http_addr = clamp_non_loopback(
|
||||
http_addr,
|
||||
cfg.dangerously_allow_non_loopback_proxy,
|
||||
"HTTP proxy",
|
||||
);
|
||||
let admin_addr = clamp_non_loopback(
|
||||
admin_addr,
|
||||
cfg.dangerously_allow_non_loopback_admin,
|
||||
"admin API",
|
||||
);
|
||||
if cfg.policy.allow_unix_sockets.is_empty() {
|
||||
return (http_addr, admin_addr);
|
||||
}
|
||||
|
||||
// `x-unix-socket` is intentionally a local escape hatch. If the proxy (or admin API) is
|
||||
// reachable from outside the machine, it can become a remote bridge into local daemons
|
||||
// (e.g. docker.sock). To avoid footguns, enforce loopback binding whenever unix sockets
|
||||
// are enabled.
|
||||
if cfg.dangerously_allow_non_loopback_proxy && !http_addr.ip().is_loopback() {
|
||||
warn!(
|
||||
"unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_proxy and clamping HTTP proxy to loopback"
|
||||
);
|
||||
}
|
||||
if cfg.dangerously_allow_non_loopback_admin && !admin_addr.ip().is_loopback() {
|
||||
warn!(
|
||||
"unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_admin and clamping admin API to loopback"
|
||||
);
|
||||
}
|
||||
(
|
||||
SocketAddr::from(([127, 0, 0, 1], http_addr.port())),
|
||||
SocketAddr::from(([127, 0, 0, 1], admin_addr.port())),
|
||||
)
|
||||
}
|
||||
|
||||
pub struct RuntimeConfig {
|
||||
pub http_addr: SocketAddr,
|
||||
pub socks_addr: SocketAddr,
|
||||
pub admin_addr: SocketAddr,
|
||||
}
|
||||
|
||||
pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
|
||||
let http_addr = resolve_addr(&cfg.network_proxy.proxy_url, 3128);
|
||||
let admin_addr = resolve_addr(&cfg.network_proxy.admin_url, 8080);
|
||||
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg.network_proxy);
|
||||
let socks_addr = SocketAddr::from(([127, 0, 0, 1], 8081));
|
||||
|
||||
RuntimeConfig {
|
||||
http_addr,
|
||||
socks_addr,
|
||||
admin_addr,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
|
||||
let addr_parts = parse_host_port(url, default_port);
|
||||
let host = if addr_parts.host.eq_ignore_ascii_case("localhost") {
|
||||
"127.0.0.1"
|
||||
} else {
|
||||
addr_parts.host
|
||||
};
|
||||
match host.parse::<IpAddr>() {
|
||||
Ok(ip) => SocketAddr::new(ip, addr_parts.port),
|
||||
Err(_) => SocketAddr::from(([127, 0, 0, 1], addr_parts.port)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct SocketAddressParts<'a> {
|
||||
host: &'a str,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
fn parse_host_port(url: &str, default_port: u16) -> SocketAddressParts<'_> {
|
||||
let trimmed = url.trim();
|
||||
if trimmed.is_empty() {
|
||||
return SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: default_port,
|
||||
};
|
||||
}
|
||||
let without_scheme = trimmed
|
||||
.split_once("://")
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(trimmed);
|
||||
let host_port = without_scheme.split('/').next().unwrap_or(without_scheme);
|
||||
let host_port = host_port
|
||||
.rsplit_once('@')
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(host_port);
|
||||
|
||||
if host_port.starts_with('[')
|
||||
&& let Some(end) = host_port.find(']')
|
||||
{
|
||||
let host = &host_port[1..end];
|
||||
let port = host_port[end + 1..]
|
||||
.strip_prefix(':')
|
||||
.and_then(|port| port.parse::<u16>().ok())
|
||||
.unwrap_or(default_port);
|
||||
return SocketAddressParts { host, port };
|
||||
}
|
||||
|
||||
// Only treat `host:port` as such when there's a single `:`. This avoids
|
||||
// accidentally interpreting unbracketed IPv6 addresses as `host:port`.
|
||||
if host_port.bytes().filter(|b| *b == b':').count() == 1
|
||||
&& let Some((host, port)) = host_port.rsplit_once(':')
|
||||
&& let Ok(port) = port.parse::<u16>()
|
||||
{
|
||||
return SocketAddressParts { host, port };
|
||||
}
|
||||
|
||||
SocketAddressParts {
|
||||
host: host_port,
|
||||
port: default_port,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_defaults_for_empty_string() {
|
||||
assert_eq!(
|
||||
parse_host_port("", 1234),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: 1234,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_defaults_for_whitespace() {
|
||||
assert_eq!(
|
||||
parse_host_port(" ", 5555),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: 5555,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_host_port_without_scheme() {
|
||||
assert_eq!(
|
||||
parse_host_port("127.0.0.1:8080", 3128),
|
||||
SocketAddressParts {
|
||||
host: "127.0.0.1",
|
||||
port: 8080,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_host_port_with_scheme_and_path() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://example.com:8080/some/path", 3128),
|
||||
SocketAddressParts {
|
||||
host: "example.com",
|
||||
port: 8080,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_strips_userinfo() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://user:pass@host.example:5555", 3128),
|
||||
SocketAddressParts {
|
||||
host: "host.example",
|
||||
port: 5555,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_parses_ipv6_with_brackets() {
|
||||
assert_eq!(
|
||||
parse_host_port("http://[::1]:9999", 3128),
|
||||
SocketAddressParts {
|
||||
host: "::1",
|
||||
port: 9999,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_does_not_treat_unbracketed_ipv6_as_host_port() {
|
||||
assert_eq!(
|
||||
parse_host_port("2001:db8::1", 3128),
|
||||
SocketAddressParts {
|
||||
host: "2001:db8::1",
|
||||
port: 3128,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_host_port_falls_back_to_default_port_when_port_is_invalid() {
|
||||
assert_eq!(
|
||||
parse_host_port("example.com:notaport", 3128),
|
||||
SocketAddressParts {
|
||||
host: "example.com:notaport",
|
||||
port: 3128,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_maps_localhost_to_loopback() {
|
||||
assert_eq!(
|
||||
resolve_addr("localhost", 3128),
|
||||
"127.0.0.1:3128".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ip_literals() {
|
||||
assert_eq!(
|
||||
resolve_addr("1.2.3.4", 80),
|
||||
"1.2.3.4:80".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_parses_ipv6_literals() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://[::1]:8080", 3128),
|
||||
"[::1]:8080".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
|
||||
assert_eq!(
|
||||
resolve_addr("http://example.com:5555", 3128),
|
||||
"127.0.0.1:5555".parse::<SocketAddr>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_bind_addrs_allows_non_loopback_when_enabled() {
|
||||
let cfg = NetworkProxyConfig {
|
||||
dangerously_allow_non_loopback_proxy: true,
|
||||
dangerously_allow_non_loopback_admin: true,
|
||||
..Default::default()
|
||||
};
|
||||
let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
|
||||
let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
|
||||
|
||||
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
|
||||
|
||||
assert_eq!(http_addr, "0.0.0.0:3128".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(admin_addr, "0.0.0.0:8080".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_bind_addrs_forces_loopback_when_unix_sockets_enabled() {
|
||||
let cfg = NetworkProxyConfig {
|
||||
dangerously_allow_non_loopback_proxy: true,
|
||||
dangerously_allow_non_loopback_admin: true,
|
||||
policy: NetworkPolicy {
|
||||
allow_unix_sockets: vec!["/tmp/docker.sock".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
|
||||
let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
|
||||
|
||||
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
|
||||
|
||||
assert_eq!(http_addr, "127.0.0.1:3128".parse::<SocketAddr>().unwrap());
|
||||
assert_eq!(admin_addr, "127.0.0.1:8080".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
}
|
||||
626
codex-rs/network-proxy/src/http_proxy.rs
Normal file
626
codex-rs/network-proxy/src/http_proxy.rs
Normal file
@@ -0,0 +1,626 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::mitm;
|
||||
use crate::network_policy::NetworkDecision;
|
||||
use crate::network_policy::NetworkPolicyDecider;
|
||||
use crate::network_policy::NetworkPolicyRequest;
|
||||
use crate::network_policy::NetworkProtocol;
|
||||
use crate::network_policy::evaluate_host_policy;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::responses::blocked_header_value;
|
||||
use crate::responses::json_response;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use crate::upstream::UpstreamClient;
|
||||
use crate::upstream::proxy_for_connect;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use rama_core::Layer;
|
||||
use rama_core::Service;
|
||||
use rama_core::error::BoxError;
|
||||
use rama_core::error::ErrorExt as _;
|
||||
use rama_core::error::OpaqueError;
|
||||
use rama_core::extensions::ExtensionsMut;
|
||||
use rama_core::extensions::ExtensionsRef;
|
||||
use rama_core::layer::AddInputExtensionLayer;
|
||||
use rama_core::rt::Executor;
|
||||
use rama_core::service::service_fn;
|
||||
use rama_http::Body;
|
||||
use rama_http::HeaderValue;
|
||||
use rama_http::Request;
|
||||
use rama_http::Response;
|
||||
use rama_http::StatusCode;
|
||||
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama_http::matcher::MethodMatcher;
|
||||
use rama_http_backend::client::proxy::layer::HttpProxyConnector;
|
||||
use rama_http_backend::server::HttpServer;
|
||||
use rama_http_backend::server::layer::upgrade::UpgradeLayer;
|
||||
use rama_http_backend::server::layer::upgrade::Upgraded;
|
||||
use rama_net::Protocol;
|
||||
use rama_net::address::ProxyAddress;
|
||||
use rama_net::client::ConnectorService;
|
||||
use rama_net::client::EstablishedClientConnection;
|
||||
use rama_net::http::RequestContext;
|
||||
use rama_net::proxy::ProxyRequest;
|
||||
use rama_net::proxy::ProxyTarget;
|
||||
use rama_net::proxy::StreamForwardService;
|
||||
use rama_net::stream::SocketInfo;
|
||||
use rama_tcp::client::Request as TcpRequest;
|
||||
use rama_tcp::client::service::TcpConnector;
|
||||
use rama_tcp::server::TcpListener;
|
||||
use rama_tls_boring::client::TlsConnectorDataBuilder;
|
||||
use rama_tls_boring::client::TlsConnectorLayer;
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn run_http_proxy(
|
||||
state: Arc<AppState>,
|
||||
addr: SocketAddr,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
) -> Result<()> {
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
// Rama's `BoxError` is a `Box<dyn Error + Send + Sync>` without an explicit `'static`
|
||||
// lifetime bound, which means it doesn't satisfy `anyhow::Context`'s `StdError` constraint.
|
||||
// Wrap it in Rama's `OpaqueError` so we can preserve the original error as a source and
|
||||
// still use `anyhow` for chaining.
|
||||
.map_err(rama_core::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
|
||||
|
||||
let http_service = HttpServer::auto(Executor::new()).service(
|
||||
(
|
||||
UpgradeLayer::new(
|
||||
MethodMatcher::CONNECT,
|
||||
service_fn({
|
||||
let policy_decider = policy_decider.clone();
|
||||
move |req| http_connect_accept(policy_decider.clone(), req)
|
||||
}),
|
||||
service_fn(http_connect_proxy),
|
||||
),
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
)
|
||||
.into_layer(service_fn({
|
||||
let policy_decider = policy_decider.clone();
|
||||
move |req| http_plain_proxy(policy_decider.clone(), req)
|
||||
})),
|
||||
);
|
||||
|
||||
info!("HTTP proxy listening on {addr}");
|
||||
|
||||
listener
|
||||
.serve(AddInputExtensionLayer::new(state).into_layer(http_service))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn http_connect_accept(
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
mut req: Request,
|
||||
) -> Result<(Response, Request), Response> {
|
||||
let app_state = req
|
||||
.extensions()
|
||||
.get::<Arc<AppState>>()
|
||||
.cloned()
|
||||
.ok_or_else(|| text_response(StatusCode::INTERNAL_SERVER_ERROR, "missing state"))?;
|
||||
|
||||
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!("CONNECT missing authority: {err}");
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
|
||||
}
|
||||
};
|
||||
|
||||
let host = normalize_host(&authority.host.to_string());
|
||||
if host.is_empty() {
|
||||
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
|
||||
}
|
||||
|
||||
let client = client_addr(&req);
|
||||
|
||||
let enabled = match app_state.enabled().await {
|
||||
Ok(enabled) => enabled,
|
||||
Err(err) => {
|
||||
error!("failed to read enabled state: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
if !enabled {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"proxy_disabled".to_string(),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
None,
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked; proxy disabled (client={client}, host={host})");
|
||||
return Err(text_response(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"proxy disabled",
|
||||
));
|
||||
}
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::HttpsConnect,
|
||||
host.clone(),
|
||||
authority.port,
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
|
||||
Ok(NetworkDecision::Deny { reason }) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
None,
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("CONNECT blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(blocked_text(&reason));
|
||||
}
|
||||
Ok(NetworkDecision::Allow) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("CONNECT allowed (client={client}, host={host})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host for CONNECT {host}: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
|
||||
let mode = match app_state.network_mode().await {
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
error!("failed to read network mode: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
|
||||
let mitm_state = match app_state.mitm_state().await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
error!("failed to load MITM state: {err}");
|
||||
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
|
||||
if mode == NetworkMode::Limited && mitm_state.is_none() {
|
||||
// Limited mode is designed to be read-only. Without MITM, a CONNECT tunnel would hide the
|
||||
// inner HTTP method/headers from the proxy, effectively bypassing method policy.
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"mitm_required".to_string(),
|
||||
client.clone(),
|
||||
Some("CONNECT".to_string()),
|
||||
Some(NetworkMode::Limited),
|
||||
"http-connect".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
"CONNECT blocked; MITM required for read-only HTTPS in limited mode (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(blocked_text("mitm_required"));
|
||||
}
|
||||
|
||||
req.extensions_mut().insert(ProxyTarget(authority));
|
||||
req.extensions_mut().insert(mode);
|
||||
if let Some(mitm_state) = mitm_state {
|
||||
req.extensions_mut().insert(mitm_state);
|
||||
}
|
||||
|
||||
Ok((
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::empty())
|
||||
.unwrap_or_else(|_| Response::new(Body::empty())),
|
||||
req,
|
||||
))
|
||||
}
|
||||
|
||||
async fn http_connect_proxy(upgraded: Upgraded) -> Result<(), Infallible> {
|
||||
let mode = upgraded
|
||||
.extensions()
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
|
||||
let Some(target) = upgraded
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.map(|t| t.0.clone())
|
||||
else {
|
||||
warn!("CONNECT missing proxy target");
|
||||
return Ok(());
|
||||
};
|
||||
let host = normalize_host(&target.host.to_string());
|
||||
|
||||
if mode == NetworkMode::Limited
|
||||
&& upgraded
|
||||
.extensions()
|
||||
.get::<Arc<mitm::MitmState>>()
|
||||
.is_some()
|
||||
{
|
||||
let port = target.port;
|
||||
info!("CONNECT MITM enabled (host={host}, port={port}, mode={mode:?})");
|
||||
if let Err(err) = mitm::mitm_tunnel(upgraded).await {
|
||||
warn!("MITM tunnel error: {err}");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let allow_upstream_proxy = match upgraded.extensions().get::<Arc<AppState>>().cloned() {
|
||||
Some(state) => match state.allow_upstream_proxy().await {
|
||||
Ok(allowed) => allowed,
|
||||
Err(err) => {
|
||||
error!("failed to read upstream proxy setting: {err}");
|
||||
false
|
||||
}
|
||||
},
|
||||
None => {
|
||||
error!("missing app state");
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
let proxy = if allow_upstream_proxy {
|
||||
proxy_for_connect()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Err(err) = forward_connect_tunnel(upgraded, proxy).await {
|
||||
warn!("tunnel error: {err}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn forward_connect_tunnel(
|
||||
upgraded: Upgraded,
|
||||
proxy: Option<ProxyAddress>,
|
||||
) -> Result<(), BoxError> {
|
||||
let authority = upgraded
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.map(|target| target.0.clone())
|
||||
.ok_or_else(|| OpaqueError::from_display("missing forward authority").into_boxed())?;
|
||||
|
||||
let mut extensions = upgraded.extensions().clone();
|
||||
if let Some(proxy) = proxy {
|
||||
extensions.insert(proxy);
|
||||
}
|
||||
|
||||
let req = TcpRequest::new_with_extensions(authority.clone(), extensions)
|
||||
.with_protocol(Protocol::HTTPS);
|
||||
let proxy_connector = HttpProxyConnector::optional(TcpConnector::new());
|
||||
let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
|
||||
let connector = TlsConnectorLayer::tunnel(None)
|
||||
.with_connector_data(tls_config)
|
||||
.into_layer(proxy_connector);
|
||||
let EstablishedClientConnection { conn: target, .. } =
|
||||
connector.connect(req).await.map_err(|err| {
|
||||
OpaqueError::from_boxed(err)
|
||||
.with_context(|| format!("establish CONNECT tunnel to {authority}"))
|
||||
.into_boxed()
|
||||
})?;
|
||||
|
||||
let proxy_req = ProxyRequest {
|
||||
source: upgraded,
|
||||
target,
|
||||
};
|
||||
StreamForwardService::default()
|
||||
.serve(proxy_req)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
OpaqueError::from_boxed(err.into())
|
||||
.with_context(|| format!("forward CONNECT tunnel to {authority}"))
|
||||
.into_boxed()
|
||||
})
|
||||
}
|
||||
|
||||
async fn http_plain_proxy(
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
req: Request,
|
||||
) -> Result<Response, Infallible> {
|
||||
let app_state = match req.extensions().get::<Arc<AppState>>().cloned() {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
error!("missing app state");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
let client = client_addr(&req);
|
||||
|
||||
let method_allowed = match app_state.method_allowed(req.method().as_str()).await {
|
||||
Ok(allowed) => allowed,
|
||||
Err(err) => {
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
|
||||
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly scoped:
|
||||
// macOS-only + explicit allowlist, to avoid turning the proxy into a general local capability
|
||||
// escalation mechanism.
|
||||
if let Some(unix_socket_header) = req.headers().get("x-unix-socket") {
|
||||
let socket_path = match unix_socket_header.to_str() {
|
||||
Ok(value) => value.to_string(),
|
||||
Err(_) => {
|
||||
warn!("invalid x-unix-socket header value (non-UTF8)");
|
||||
return Ok(text_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid x-unix-socket header",
|
||||
));
|
||||
}
|
||||
};
|
||||
let enabled = match app_state.enabled().await {
|
||||
Ok(enabled) => enabled,
|
||||
Err(err) => {
|
||||
error!("failed to read enabled state: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
if !enabled {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
socket_path.clone(),
|
||||
"proxy_disabled".to_string(),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
"unix-socket".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("unix socket blocked; proxy disabled (client={client}, path={socket_path})");
|
||||
return Ok(text_response(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"proxy disabled",
|
||||
));
|
||||
}
|
||||
if !method_allowed {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!(
|
||||
"unix socket blocked by method policy (client={client}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(json_blocked("unix-socket", "method_not_allowed"));
|
||||
}
|
||||
|
||||
if !cfg!(target_os = "macos") {
|
||||
warn!("unix socket proxy unsupported on this platform (path={socket_path})");
|
||||
return Ok(text_response(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"unix sockets unsupported",
|
||||
));
|
||||
}
|
||||
|
||||
match app_state.is_unix_socket_allowed(&socket_path).await {
|
||||
Ok(true) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("unix socket allowed (client={client}, path={socket_path})");
|
||||
match proxy_via_unix_socket(req, &socket_path).await {
|
||||
Ok(resp) => return Ok(resp),
|
||||
Err(err) => {
|
||||
warn!("unix socket proxy failed: {err}");
|
||||
return Ok(text_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"unix socket proxy failed",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("unix socket blocked (client={client}, path={socket_path})");
|
||||
return Ok(json_blocked("unix-socket", "not_allowed"));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("unix socket check failed: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
|
||||
Ok(authority) => authority,
|
||||
Err(err) => {
|
||||
warn!("missing host: {err}");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
|
||||
}
|
||||
};
|
||||
let host = normalize_host(&authority.host.to_string());
|
||||
let port = authority.port;
|
||||
let enabled = match app_state.enabled().await {
|
||||
Ok(enabled) => enabled,
|
||||
Err(err) => {
|
||||
error!("failed to read enabled state: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
if !enabled {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"proxy_disabled".to_string(),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!("request blocked; proxy disabled (client={client}, host={host}, method={method})");
|
||||
return Ok(text_response(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"proxy disabled",
|
||||
));
|
||||
}
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Http,
|
||||
host.clone(),
|
||||
port,
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
|
||||
Ok(NetworkDecision::Deny { reason }) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
None,
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("request blocked (client={client}, host={host}, reason={reason})");
|
||||
return Ok(json_blocked(&host, &reason));
|
||||
}
|
||||
Ok(NetworkDecision::Allow) => {}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host for {host}: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
}
|
||||
|
||||
if !method_allowed {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
Some(req.method().as_str().to_string()),
|
||||
Some(NetworkMode::Limited),
|
||||
"http".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
warn!(
|
||||
"request blocked by method policy (client={client}, host={host}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(json_blocked(&host, "method_not_allowed"));
|
||||
}
|
||||
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
let method = req.method();
|
||||
info!("request allowed (client={client}, host={host}, method={method})");
|
||||
|
||||
let allow_upstream_proxy = match app_state.allow_upstream_proxy().await {
|
||||
Ok(allow) => allow,
|
||||
Err(err) => {
|
||||
error!("failed to read upstream proxy config: {err}");
|
||||
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
|
||||
}
|
||||
};
|
||||
let client = if allow_upstream_proxy {
|
||||
UpstreamClient::from_env_proxy()
|
||||
} else {
|
||||
UpstreamClient::direct()
|
||||
};
|
||||
|
||||
match client.serve(req).await {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(err) => {
|
||||
warn!("upstream request failed: {err}");
|
||||
Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_via_unix_socket(req: Request, socket_path: &str) -> Result<Response> {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let client = UpstreamClient::unix_socket(socket_path);
|
||||
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let path = parts
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map(rama_http::uri::PathAndQuery::as_str)
|
||||
.unwrap_or("/");
|
||||
parts.uri = path
|
||||
.parse()
|
||||
.with_context(|| format!("invalid unix socket request path: {path}"))?;
|
||||
parts.headers.remove("x-unix-socket");
|
||||
|
||||
let req = Request::from_parts(parts, body);
|
||||
client.serve(req).await.map_err(anyhow::Error::from)
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = req;
|
||||
let _ = socket_path;
|
||||
Err(anyhow::anyhow!("unix sockets not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
fn client_addr<T: ExtensionsRef>(input: &T) -> Option<String> {
|
||||
input
|
||||
.extensions()
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string())
|
||||
}
|
||||
|
||||
fn json_blocked(host: &str, reason: &str) -> Response {
|
||||
let response = BlockedResponse {
|
||||
status: "blocked",
|
||||
host,
|
||||
reason,
|
||||
};
|
||||
let mut resp = json_response(&response);
|
||||
*resp.status_mut() = StatusCode::FORBIDDEN;
|
||||
resp.headers_mut().insert(
|
||||
"x-proxy-error",
|
||||
HeaderValue::from_static(blocked_header_value(reason)),
|
||||
);
|
||||
resp
|
||||
}
|
||||
|
||||
fn blocked_text(reason: &str) -> Response {
|
||||
crate::responses::blocked_text_response(reason)
|
||||
}
|
||||
|
||||
fn text_response(status: StatusCode, body: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct BlockedResponse<'a> {
|
||||
status: &'static str,
|
||||
host: &'a str,
|
||||
reason: &'a str,
|
||||
}
|
||||
17
codex-rs/network-proxy/src/init.rs
Normal file
17
codex-rs/network-proxy/src/init.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_core::config::find_codex_home;
|
||||
use std::fs;
|
||||
|
||||
pub fn run_init() -> Result<()> {
|
||||
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
|
||||
let root = codex_home.join("network_proxy");
|
||||
let mitm_dir = root.join("mitm");
|
||||
|
||||
fs::create_dir_all(&root).with_context(|| format!("failed to create {}", root.display()))?;
|
||||
fs::create_dir_all(&mitm_dir)
|
||||
.with_context(|| format!("failed to create {}", mitm_dir.display()))?;
|
||||
|
||||
println!("ensured {}", mitm_dir.display());
|
||||
Ok(())
|
||||
}
|
||||
35
codex-rs/network-proxy/src/lib.rs
Normal file
35
codex-rs/network-proxy/src/lib.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
mod admin;
|
||||
mod config;
|
||||
mod http_proxy;
|
||||
mod init;
|
||||
mod mitm;
|
||||
mod network_policy;
|
||||
mod policy;
|
||||
mod proxy;
|
||||
mod responses;
|
||||
mod runtime;
|
||||
mod socks5;
|
||||
mod state;
|
||||
mod upstream;
|
||||
|
||||
use anyhow::Result;
|
||||
pub use network_policy::NetworkDecision;
|
||||
pub use network_policy::NetworkPolicyDecider;
|
||||
pub use network_policy::NetworkPolicyRequest;
|
||||
pub use network_policy::NetworkProtocol;
|
||||
pub use proxy::Args;
|
||||
pub use proxy::Command;
|
||||
pub use proxy::NetworkProxy;
|
||||
pub use proxy::NetworkProxyBuilder;
|
||||
pub use proxy::NetworkProxyHandle;
|
||||
pub use proxy::run_init;
|
||||
|
||||
pub async fn run_main(args: Args) -> Result<()> {
|
||||
if let Some(Command::Init) = args.command {
|
||||
run_init()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let proxy = NetworkProxy::from_cli_args(args).await?;
|
||||
proxy.run().await?.wait().await
|
||||
}
|
||||
13
codex-rs/network-proxy/src/main.rs
Normal file
13
codex-rs/network-proxy/src/main.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use codex_network_proxy::Args;
|
||||
use codex_network_proxy::NetworkProxy;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let args = Args::parse();
|
||||
let proxy = NetworkProxy::from_cli_args(args).await?;
|
||||
proxy.run().await?.wait().await
|
||||
}
|
||||
622
codex-rs/network-proxy/src/mitm.rs
Normal file
622
codex-rs/network-proxy/src/mitm.rs
Normal file
@@ -0,0 +1,622 @@
|
||||
use crate::config::MitmConfig;
|
||||
use crate::config::NetworkMode;
|
||||
use crate::policy::method_allowed;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::responses::blocked_text_response;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use crate::upstream::UpstreamClient;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rama_core::Layer;
|
||||
use rama_core::Service;
|
||||
use rama_core::bytes::Bytes;
|
||||
use rama_core::error::BoxError;
|
||||
use rama_core::extensions::ExtensionsRef;
|
||||
use rama_core::futures::stream::Stream;
|
||||
use rama_core::rt::Executor;
|
||||
use rama_core::service::service_fn;
|
||||
use rama_http::Body;
|
||||
use rama_http::BodyDataStream;
|
||||
use rama_http::HeaderValue;
|
||||
use rama_http::Request;
|
||||
use rama_http::Response;
|
||||
use rama_http::StatusCode;
|
||||
use rama_http::Uri;
|
||||
use rama_http::header::HOST;
|
||||
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
|
||||
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
|
||||
use rama_http_backend::server::HttpServer;
|
||||
use rama_http_backend::server::layer::upgrade::Upgraded;
|
||||
use rama_net::proxy::ProxyTarget;
|
||||
use rama_net::stream::SocketInfo;
|
||||
use rama_net::tls::ApplicationProtocol;
|
||||
use rama_net::tls::DataEncoding;
|
||||
use rama_net::tls::server::ServerAuth;
|
||||
use rama_net::tls::server::ServerAuthData;
|
||||
use rama_net::tls::server::ServerConfig;
|
||||
use rama_tls_boring::server::TlsAcceptorData;
|
||||
use rama_tls_boring::server::TlsAcceptorLayer;
|
||||
use rama_utils::str::NonEmptyStr;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::net::IpAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context as TaskContext;
|
||||
use std::task::Poll;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use rcgen_rama::BasicConstraints;
|
||||
use rcgen_rama::CertificateParams;
|
||||
use rcgen_rama::DistinguishedName;
|
||||
use rcgen_rama::DnType;
|
||||
use rcgen_rama::ExtendedKeyUsagePurpose;
|
||||
use rcgen_rama::IsCa;
|
||||
use rcgen_rama::Issuer;
|
||||
use rcgen_rama::KeyPair;
|
||||
use rcgen_rama::KeyUsagePurpose;
|
||||
use rcgen_rama::SanType;
|
||||
|
||||
pub struct MitmState {
|
||||
issuer: Issuer<'static, KeyPair>,
|
||||
upstream: UpstreamClient,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MitmState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Avoid dumping internal state (CA material, connectors, etc.) to logs.
|
||||
f.debug_struct("MitmState")
|
||||
.field("inspect", &self.inspect)
|
||||
.field("max_body_bytes", &self.max_body_bytes)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl MitmState {
|
||||
pub fn new(cfg: &MitmConfig, allow_upstream_proxy: bool) -> Result<Self> {
|
||||
// MITM exists to make limited-mode HTTPS enforceable: once CONNECT is established, plain
|
||||
// proxying would lose visibility into the inner HTTP request. We generate/load a local CA
|
||||
// and issue per-host leaf certs so we can terminate TLS and apply policy.
|
||||
let (ca_cert_pem, ca_key_pem) = load_or_create_ca(cfg)?;
|
||||
let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
|
||||
let issuer: Issuer<'static, KeyPair> =
|
||||
Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key).context("failed to parse CA cert")?;
|
||||
|
||||
let upstream = if allow_upstream_proxy {
|
||||
UpstreamClient::from_env_proxy()
|
||||
} else {
|
||||
UpstreamClient::direct()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer,
|
||||
upstream,
|
||||
inspect: cfg.inspect,
|
||||
max_body_bytes: cfg.max_body_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
|
||||
let (cert_pem, key_pem) = issue_host_certificate_pem(host, &self.issuer)?;
|
||||
let cert_chain = DataEncoding::Pem(
|
||||
NonEmptyStr::try_from(cert_pem.as_str()).context("failed to encode host cert PEM")?,
|
||||
);
|
||||
let private_key = DataEncoding::Pem(
|
||||
NonEmptyStr::try_from(key_pem.as_str()).context("failed to encode host key PEM")?,
|
||||
);
|
||||
let auth = ServerAuthData {
|
||||
private_key,
|
||||
cert_chain,
|
||||
ocsp: None,
|
||||
};
|
||||
|
||||
let mut server_config = ServerConfig::new(ServerAuth::Single(auth));
|
||||
server_config.application_layer_protocol_negotiation = Some(vec![
|
||||
ApplicationProtocol::HTTP_2,
|
||||
ApplicationProtocol::HTTP_11,
|
||||
]);
|
||||
|
||||
TlsAcceptorData::try_from(server_config).context("failed to build boring acceptor config")
|
||||
}
|
||||
|
||||
pub fn inspect_enabled(&self) -> bool {
|
||||
self.inspect
|
||||
}
|
||||
|
||||
pub fn max_body_bytes(&self) -> usize {
|
||||
self.max_body_bytes
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
|
||||
let state = upgraded
|
||||
.extensions()
|
||||
.get::<Arc<MitmState>>()
|
||||
.cloned()
|
||||
.context("missing MITM state")?;
|
||||
let target = upgraded
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.context("missing proxy target")?
|
||||
.0
|
||||
.clone();
|
||||
let host = normalize_host(&target.host.to_string());
|
||||
let acceptor_data = state.tls_acceptor_data_for_host(&host)?;
|
||||
|
||||
let executor = upgraded
|
||||
.extensions()
|
||||
.get::<Executor>()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let http_service = HttpServer::auto(executor).service(
|
||||
(
|
||||
RemoveResponseHeaderLayer::hop_by_hop(),
|
||||
RemoveRequestHeaderLayer::hop_by_hop(),
|
||||
)
|
||||
.into_layer(service_fn(handle_mitm_request)),
|
||||
);
|
||||
|
||||
let https_service = TlsAcceptorLayer::new(acceptor_data)
|
||||
.with_store_client_hello(true)
|
||||
.into_layer(http_service);
|
||||
|
||||
https_service
|
||||
.serve(upgraded)
|
||||
.await
|
||||
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_mitm_request(req: Request) -> Result<Response, std::convert::Infallible> {
|
||||
let response = match forward_request(req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
warn!("MITM upstream request failed: {err}");
|
||||
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
|
||||
}
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn forward_request(req: Request) -> Result<Response> {
|
||||
let target = req
|
||||
.extensions()
|
||||
.get::<ProxyTarget>()
|
||||
.context("missing proxy target")?
|
||||
.0
|
||||
.clone();
|
||||
|
||||
let target_host = normalize_host(&target.host.to_string());
|
||||
let target_port = target.port;
|
||||
let mode = req
|
||||
.extensions()
|
||||
.get::<NetworkMode>()
|
||||
.copied()
|
||||
.unwrap_or(NetworkMode::Full);
|
||||
let mitm = req
|
||||
.extensions()
|
||||
.get::<Arc<MitmState>>()
|
||||
.cloned()
|
||||
.context("missing MITM state")?;
|
||||
let app_state = req
|
||||
.extensions()
|
||||
.get::<Arc<AppState>>()
|
||||
.cloned()
|
||||
.context("missing app state")?;
|
||||
|
||||
if req.method().as_str() == "CONNECT" {
|
||||
return Ok(text_response(
|
||||
StatusCode::METHOD_NOT_ALLOWED,
|
||||
"CONNECT not supported inside MITM",
|
||||
));
|
||||
}
|
||||
|
||||
let method = req.method().as_str().to_string();
|
||||
let path = path_and_query(req.uri());
|
||||
let client = req
|
||||
.extensions()
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
|
||||
if let Some(request_host) = extract_request_host(&req) {
|
||||
let normalized = normalize_host(&request_host);
|
||||
if !normalized.is_empty() && normalized != target_host {
|
||||
warn!("MITM host mismatch (target={target_host}, request_host={normalized})");
|
||||
return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
|
||||
}
|
||||
}
|
||||
|
||||
if !method_allowed(mode, method.as_str()) {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
target_host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
Some(method.clone()),
|
||||
Some(NetworkMode::Limited),
|
||||
"https".to_string(),
|
||||
))
|
||||
.await;
|
||||
warn!(
|
||||
"MITM blocked by method policy (host={target_host}, method={method}, path={path}, mode={mode:?}, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Ok(blocked_text("method_not_allowed"));
|
||||
}
|
||||
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let authority = authority_header_value(&target_host, target_port);
|
||||
parts.uri = build_https_uri(&authority, &path)?;
|
||||
parts
|
||||
.headers
|
||||
.insert(HOST, HeaderValue::from_str(&authority)?);
|
||||
|
||||
let inspect = mitm.inspect_enabled();
|
||||
let max_body_bytes = mitm.max_body_bytes();
|
||||
let body = if inspect {
|
||||
inspect_body(
|
||||
body,
|
||||
max_body_bytes,
|
||||
RequestLogContext {
|
||||
host: authority.clone(),
|
||||
method: method.clone(),
|
||||
path: path.clone(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
body
|
||||
};
|
||||
|
||||
let upstream_req = Request::from_parts(parts, body);
|
||||
let upstream_resp = mitm.upstream.serve(upstream_req).await?;
|
||||
respond_with_inspection(
|
||||
upstream_resp,
|
||||
inspect,
|
||||
max_body_bytes,
|
||||
&method,
|
||||
&path,
|
||||
&authority,
|
||||
)
|
||||
}
|
||||
|
||||
fn respond_with_inspection(
|
||||
resp: Response,
|
||||
inspect: bool,
|
||||
max_body_bytes: usize,
|
||||
method: &str,
|
||||
path: &str,
|
||||
authority: &str,
|
||||
) -> Result<Response> {
|
||||
if !inspect {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
let (parts, body) = resp.into_parts();
|
||||
let body = inspect_body(
|
||||
body,
|
||||
max_body_bytes,
|
||||
ResponseLogContext {
|
||||
host: authority.to_string(),
|
||||
method: method.to_string(),
|
||||
path: path.to_string(),
|
||||
status: parts.status,
|
||||
},
|
||||
);
|
||||
Ok(Response::from_parts(parts, body))
|
||||
}
|
||||
|
||||
fn inspect_body<T: BodyLoggable + Send + 'static>(
|
||||
body: Body,
|
||||
max_body_bytes: usize,
|
||||
ctx: T,
|
||||
) -> Body {
|
||||
Body::from_stream(InspectStream {
|
||||
inner: Box::pin(body.into_data_stream()),
|
||||
ctx: Some(Box::new(ctx)),
|
||||
len: 0,
|
||||
max_body_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
struct InspectStream<T> {
|
||||
inner: Pin<Box<BodyDataStream>>,
|
||||
ctx: Option<Box<T>>,
|
||||
len: usize,
|
||||
max_body_bytes: usize,
|
||||
}
|
||||
|
||||
impl<T: BodyLoggable> Stream for InspectStream<T> {
|
||||
type Item = Result<Bytes, BoxError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
match this.inner.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(bytes))) => {
|
||||
this.len = this.len.saturating_add(bytes.len());
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(None) => {
|
||||
if let Some(ctx) = this.ctx.take() {
|
||||
ctx.log(this.len, this.len > this.max_body_bytes);
|
||||
}
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RequestLogContext {
|
||||
host: String,
|
||||
method: String,
|
||||
path: String,
|
||||
}
|
||||
|
||||
struct ResponseLogContext {
|
||||
host: String,
|
||||
method: String,
|
||||
path: String,
|
||||
status: StatusCode,
|
||||
}
|
||||
|
||||
trait BodyLoggable {
|
||||
fn log(self, len: usize, truncated: bool);
|
||||
}
|
||||
|
||||
impl BodyLoggable for RequestLogContext {
|
||||
fn log(self, len: usize, truncated: bool) {
|
||||
let host = self.host;
|
||||
let method = self.method;
|
||||
let path = self.path;
|
||||
info!(
|
||||
"MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl BodyLoggable for ResponseLogContext {
|
||||
fn log(self, len: usize, truncated: bool) {
|
||||
let host = self.host;
|
||||
let method = self.method;
|
||||
let path = self.path;
|
||||
let status = self.status;
|
||||
info!(
|
||||
"MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_request_host(req: &Request) -> Option<String> {
|
||||
req.headers()
|
||||
.get(HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| req.uri().authority().map(|a| a.as_str().to_string()))
|
||||
}
|
||||
|
||||
fn authority_header_value(host: &str, port: u16) -> String {
|
||||
// Host header / URI authority formatting.
|
||||
if host.contains(':') {
|
||||
if port == 443 {
|
||||
format!("[{host}]")
|
||||
} else {
|
||||
format!("[{host}]:{port}")
|
||||
}
|
||||
} else if port == 443 {
|
||||
host.to_string()
|
||||
} else {
|
||||
format!("{host}:{port}")
|
||||
}
|
||||
}
|
||||
|
||||
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
|
||||
let target = format!("https://{authority}{path}");
|
||||
Ok(target.parse()?)
|
||||
}
|
||||
|
||||
fn path_and_query(uri: &Uri) -> String {
|
||||
uri.path_and_query()
|
||||
.map(rama_http::uri::PathAndQuery::as_str)
|
||||
.unwrap_or("/")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn issue_host_certificate_pem(
|
||||
host: &str,
|
||||
issuer: &Issuer<'_, KeyPair>,
|
||||
) -> Result<(String, String)> {
|
||||
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
let mut params = CertificateParams::new(Vec::new())
|
||||
.map_err(|err| anyhow!("failed to create cert params: {err}"))?;
|
||||
params.subject_alt_names.push(SanType::IpAddress(ip));
|
||||
params
|
||||
} else {
|
||||
CertificateParams::new(vec![host.to_string()])
|
||||
.map_err(|err| anyhow!("failed to create cert params: {err}"))?
|
||||
};
|
||||
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
|
||||
let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
|
||||
.map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
|
||||
let cert = params
|
||||
.signed_by(&key_pair, issuer)
|
||||
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
|
||||
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
}
|
||||
|
||||
fn load_or_create_ca(cfg: &MitmConfig) -> Result<(String, String)> {
|
||||
let cert_path = &cfg.ca_cert_path;
|
||||
let key_path = &cfg.ca_key_path;
|
||||
|
||||
if cert_path.exists() || key_path.exists() {
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
return Err(anyhow!("both ca_cert_path and ca_key_path must exist"));
|
||||
}
|
||||
let cert_pem = fs::read_to_string(cert_path)
|
||||
.with_context(|| format!("failed to read CA cert {}", cert_path.display()))?;
|
||||
let key_pem = fs::read_to_string(key_path)
|
||||
.with_context(|| format!("failed to read CA key {}", key_path.display()))?;
|
||||
return Ok((cert_pem, key_pem));
|
||||
}
|
||||
|
||||
if let Some(parent) = cert_path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("failed to create {}", parent.display()))?;
|
||||
}
|
||||
if let Some(parent) = key_path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("failed to create {}", parent.display()))?;
|
||||
}
|
||||
|
||||
let (cert_pem, key_pem) = generate_ca()?;
|
||||
// The CA key is a high-value secret. Create it atomically with restrictive permissions.
|
||||
// The cert can be world-readable, but we still write it atomically to avoid partial writes.
|
||||
//
|
||||
// We intentionally use create-new semantics: if a key already exists, we should not overwrite
|
||||
// it silently (that would invalidate previously-trusted cert chains).
|
||||
write_atomic_create_new(key_path, key_pem.as_bytes(), 0o600)
|
||||
.with_context(|| format!("failed to persist CA key {}", key_path.display()))?;
|
||||
if let Err(err) = write_atomic_create_new(cert_path, cert_pem.as_bytes(), 0o644)
|
||||
.with_context(|| format!("failed to persist CA cert {}", cert_path.display()))
|
||||
{
|
||||
// Avoid leaving a partially-created CA around (cert missing) if the second write fails.
|
||||
let _ = fs::remove_file(key_path);
|
||||
return Err(err);
|
||||
}
|
||||
let cert_path = cert_path.display();
|
||||
let key_path = key_path.display();
|
||||
info!("generated MITM CA (cert_path={cert_path}, key_path={key_path})");
|
||||
Ok((cert_pem, key_pem))
|
||||
}
|
||||
|
||||
fn generate_ca() -> Result<(String, String)> {
|
||||
let mut params = CertificateParams::default();
|
||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::KeyCertSign,
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::CommonName, "network_proxy MITM CA");
|
||||
params.distinguished_name = dn;
|
||||
|
||||
let key_pair = KeyPair::generate_for(&rcgen_rama::PKCS_ECDSA_P256_SHA256)
|
||||
.map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
|
||||
let cert = params
|
||||
.self_signed(&key_pair)
|
||||
.map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
|
||||
Ok((cert.pem(), key_pair.serialize_pem()))
|
||||
}
|
||||
|
||||
fn write_atomic_create_new(path: &std::path::Path, contents: &[u8], mode: u32) -> Result<()> {
|
||||
let parent = path
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("missing parent directory"))?;
|
||||
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
let pid = std::process::id();
|
||||
let file_name = path.file_name().unwrap_or_default().to_string_lossy();
|
||||
let tmp_path = parent.join(format!(".{file_name}.tmp.{pid}.{nanos}"));
|
||||
|
||||
let mut file = open_create_new_with_mode(&tmp_path, mode)?;
|
||||
file.write_all(contents)
|
||||
.with_context(|| format!("failed to write {}", tmp_path.display()))?;
|
||||
file.sync_all()
|
||||
.with_context(|| format!("failed to fsync {}", tmp_path.display()))?;
|
||||
drop(file);
|
||||
|
||||
// Create the final file using "create-new" semantics (no overwrite). `rename` on Unix can
|
||||
// overwrite existing files, so prefer a hard-link, which fails if the destination exists.
|
||||
match fs::hard_link(&tmp_path, path) {
|
||||
Ok(()) => {
|
||||
fs::remove_file(&tmp_path)
|
||||
.with_context(|| format!("failed to remove {}", tmp_path.display()))?;
|
||||
}
|
||||
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
// Best-effort fallback for environments where hard links are not supported.
|
||||
// This is still subject to a TOCTOU race, but the typical case is a private per-user
|
||||
// config directory, where other users cannot create files anyway.
|
||||
if path.exists() {
|
||||
let _ = fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"refusing to overwrite existing file {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
fs::rename(&tmp_path, path).with_context(|| {
|
||||
format!(
|
||||
"failed to rename {} -> {}",
|
||||
tmp_path.display(),
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
// Best-effort durability: ensure the directory entry is persisted too.
|
||||
let dir = File::open(parent).with_context(|| format!("failed to open {}", parent.display()))?;
|
||||
dir.sync_all()
|
||||
.with_context(|| format!("failed to fsync {}", parent.display()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn open_create_new_with_mode(path: &std::path::Path, mode: u32) -> Result<File> {
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.mode(mode)
|
||||
.open(path)
|
||||
.with_context(|| format!("failed to create {}", path.display()))
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn open_create_new_with_mode(path: &std::path::Path, _mode: u32) -> Result<File> {
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.open(path)
|
||||
.with_context(|| format!("failed to create {}", path.display()))
|
||||
}
|
||||
|
||||
fn blocked_text(reason: &str) -> Response {
|
||||
blocked_text_response(reason)
|
||||
}
|
||||
|
||||
fn text_response(status: StatusCode, body: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
|
||||
}
|
||||
230
codex-rs/network-proxy/src/network_policy.rs
Normal file
230
codex-rs/network-proxy/src/network_policy.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
use crate::state::AppState;
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum NetworkProtocol {
|
||||
Http,
|
||||
HttpsConnect,
|
||||
Socks5Tcp,
|
||||
Socks5Udp,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NetworkPolicyRequest {
|
||||
pub protocol: NetworkProtocol,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub client_addr: Option<String>,
|
||||
pub method: Option<String>,
|
||||
pub command: Option<String>,
|
||||
pub exec_policy_hint: Option<String>,
|
||||
}
|
||||
|
||||
impl NetworkPolicyRequest {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
protocol: NetworkProtocol,
|
||||
host: String,
|
||||
port: u16,
|
||||
client_addr: Option<String>,
|
||||
method: Option<String>,
|
||||
command: Option<String>,
|
||||
exec_policy_hint: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
protocol,
|
||||
host,
|
||||
port,
|
||||
client_addr,
|
||||
method,
|
||||
command,
|
||||
exec_policy_hint,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum NetworkDecision {
|
||||
Allow,
|
||||
Deny { reason: String },
|
||||
}
|
||||
|
||||
impl NetworkDecision {
|
||||
#[must_use]
|
||||
pub fn deny(reason: impl Into<String>) -> Self {
|
||||
let reason = reason.into();
|
||||
let reason = if reason.is_empty() {
|
||||
"policy_denied".to_string()
|
||||
} else {
|
||||
reason
|
||||
};
|
||||
Self::Deny { reason }
|
||||
}
|
||||
}
|
||||
|
||||
/// Decide whether a network request should be allowed.
|
||||
///
|
||||
/// If `command` or `exec_policy_hint` is provided, callers can map exec-policy
|
||||
/// approvals to network access (e.g., allow all requests for commands matching
|
||||
/// approved prefixes like `curl *`).
|
||||
#[async_trait]
|
||||
pub trait NetworkPolicyDecider: Send + Sync + 'static {
|
||||
async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<D: NetworkPolicyDecider + ?Sized> NetworkPolicyDecider for Arc<D> {
|
||||
async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
|
||||
(**self).decide(req).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, Fut> NetworkPolicyDecider for F
|
||||
where
|
||||
F: Fn(NetworkPolicyRequest) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = NetworkDecision> + Send,
|
||||
{
|
||||
async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
|
||||
(self)(req).await
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn evaluate_host_policy(
|
||||
state: &AppState,
|
||||
decider: Option<&Arc<dyn NetworkPolicyDecider>>,
|
||||
request: &NetworkPolicyRequest,
|
||||
) -> Result<NetworkDecision> {
|
||||
let (blocked, reason) = state.host_blocked(&request.host, request.port).await?;
|
||||
if !blocked {
|
||||
return Ok(NetworkDecision::Allow);
|
||||
}
|
||||
|
||||
if reason == "not_allowed"
|
||||
&& let Some(decider) = decider
|
||||
{
|
||||
return Ok(decider.decide(request.clone()).await);
|
||||
}
|
||||
|
||||
Ok(NetworkDecision::deny(reason))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::config::NetworkPolicy;
|
||||
use crate::state::app_state_for_policy;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
#[tokio::test]
|
||||
async fn evaluate_host_policy_invokes_decider_for_not_allowed() {
|
||||
let state = app_state_for_policy(NetworkPolicy::default());
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
|
||||
let calls = calls.clone();
|
||||
move |_req| {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
async { NetworkDecision::Allow }
|
||||
}
|
||||
});
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Http,
|
||||
"example.com".to_string(),
|
||||
80,
|
||||
None,
|
||||
Some("GET".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let decision = evaluate_host_policy(&state, Some(&decider), &request)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(decision, NetworkDecision::Allow);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn evaluate_host_policy_skips_decider_for_denied() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
denied_domains: vec!["blocked.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
|
||||
let calls = calls.clone();
|
||||
move |_req| {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
async { NetworkDecision::Allow }
|
||||
}
|
||||
});
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Http,
|
||||
"blocked.com".to_string(),
|
||||
80,
|
||||
None,
|
||||
Some("GET".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let decision = evaluate_host_policy(&state, Some(&decider), &request)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
decision,
|
||||
NetworkDecision::Deny {
|
||||
reason: "denied".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn evaluate_host_policy_skips_decider_for_not_allowed_local() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
|
||||
let calls = calls.clone();
|
||||
move |_req| {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
async { NetworkDecision::Allow }
|
||||
}
|
||||
});
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Http,
|
||||
"127.0.0.1".to_string(),
|
||||
80,
|
||||
None,
|
||||
Some("GET".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let decision = evaluate_host_policy(&state, Some(&decider), &request)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
decision,
|
||||
NetworkDecision::Deny {
|
||||
reason: "not_allowed_local".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
}
|
||||
338
codex-rs/network-proxy/src/policy.rs
Normal file
338
codex-rs/network-proxy/src/policy.rs
Normal file
@@ -0,0 +1,338 @@
|
||||
use crate::config::NetworkMode;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use globset::GlobBuilder;
|
||||
use globset::GlobSet;
|
||||
use globset::GlobSetBuilder;
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
pub fn method_allowed(mode: NetworkMode, method: &str) -> bool {
|
||||
match mode {
|
||||
NetworkMode::Full => true,
|
||||
NetworkMode::Limited => matches!(method, "GET" | "HEAD" | "OPTIONS"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_loopback_host(host: &str) -> bool {
|
||||
let host = host.to_ascii_lowercase();
|
||||
if host == "localhost" || host == "localhost." {
|
||||
return true;
|
||||
}
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return ip.is_loopback();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn is_non_public_ip(ip: IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ip) => is_non_public_ipv4(ip),
|
||||
IpAddr::V6(ip) => is_non_public_ipv6(ip),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
|
||||
// Use the standard library classification helpers where possible; they encode the intent more
|
||||
// clearly than hand-rolled range checks.
|
||||
ip.is_loopback()
|
||||
|| ip.is_private()
|
||||
|| ip.is_link_local()
|
||||
|| ip.is_unspecified()
|
||||
|| ip.is_multicast()
|
||||
}
|
||||
|
||||
fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
|
||||
if let Some(v4) = ip.to_ipv4() {
|
||||
return is_non_public_ipv4(v4) || ip.is_loopback();
|
||||
}
|
||||
// Treat anything that isn't globally routable as "local" for SSRF prevention. In particular:
|
||||
// - `::1` loopback
|
||||
// - `fc00::/7` unique-local (RFC 4193)
|
||||
// - `fe80::/10` link-local
|
||||
// - `::` unspecified
|
||||
// - multicast ranges
|
||||
ip.is_loopback()
|
||||
|| ip.is_unspecified()
|
||||
|| ip.is_multicast()
|
||||
|| ip.is_unique_local()
|
||||
|| ip.is_unicast_link_local()
|
||||
}
|
||||
|
||||
pub fn normalize_host(host: &str) -> String {
|
||||
let host = host.trim();
|
||||
if host.starts_with('[')
|
||||
&& let Some(end) = host.find(']')
|
||||
{
|
||||
return normalize_dns_host(&host[1..end]);
|
||||
}
|
||||
|
||||
// The proxy stack should typically hand us a host without a port, but be
|
||||
// defensive and strip `:port` when there is exactly one `:`.
|
||||
if host.bytes().filter(|b| *b == b':').count() == 1 {
|
||||
let host = host.split(':').next().unwrap_or_default();
|
||||
return normalize_dns_host(host);
|
||||
}
|
||||
|
||||
// Avoid mangling unbracketed IPv6 literals, but strip trailing dots so fully qualified domain
|
||||
// names are treated the same as their dotless variants.
|
||||
normalize_dns_host(host)
|
||||
}
|
||||
|
||||
fn normalize_dns_host(host: &str) -> String {
|
||||
let host = host.to_ascii_lowercase();
|
||||
host.trim_end_matches('.').to_string()
|
||||
}
|
||||
|
||||
fn normalize_pattern(pattern: &str) -> String {
|
||||
let pattern = pattern.trim();
|
||||
if pattern == "*" {
|
||||
return "*".to_string();
|
||||
}
|
||||
|
||||
let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
|
||||
("**.", domain)
|
||||
} else if let Some(domain) = pattern.strip_prefix("*.") {
|
||||
("*.", domain)
|
||||
} else {
|
||||
("", pattern)
|
||||
};
|
||||
|
||||
let remainder = normalize_host(remainder);
|
||||
if prefix.is_empty() {
|
||||
remainder
|
||||
} else {
|
||||
format!("{prefix}{remainder}")
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn compile_globset(patterns: &[String]) -> Result<GlobSet> {
|
||||
let mut builder = GlobSetBuilder::new();
|
||||
let mut seen = HashSet::new();
|
||||
for pattern in patterns {
|
||||
let pattern = normalize_pattern(pattern);
|
||||
// Supported domain patterns:
|
||||
// - "example.com": match the exact host
|
||||
// - "*.example.com": match any subdomain (not the apex)
|
||||
// - "**.example.com": match the apex and any subdomain
|
||||
// - "*": match any host
|
||||
for candidate in expand_domain_pattern(&pattern) {
|
||||
if !seen.insert(candidate.clone()) {
|
||||
continue;
|
||||
}
|
||||
let glob = GlobBuilder::new(&candidate)
|
||||
.case_insensitive(true)
|
||||
.build()
|
||||
.with_context(|| format!("invalid glob pattern: {candidate}"))?;
|
||||
builder.add(glob);
|
||||
}
|
||||
}
|
||||
Ok(builder.build()?)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum DomainPattern {
|
||||
Any,
|
||||
ApexAndSubdomains(String),
|
||||
SubdomainsOnly(String),
|
||||
Exact(String),
|
||||
}
|
||||
|
||||
impl DomainPattern {
|
||||
pub(crate) fn parse(input: &str) -> Self {
|
||||
if input == "*" {
|
||||
Self::Any
|
||||
} else if let Some(domain) = input.strip_prefix("**.") {
|
||||
Self::ApexAndSubdomains(domain.to_string())
|
||||
} else if let Some(domain) = input.strip_prefix("*.") {
|
||||
Self::SubdomainsOnly(domain.to_string())
|
||||
} else {
|
||||
Self::Exact(input.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
|
||||
match self {
|
||||
DomainPattern::Any => true,
|
||||
DomainPattern::Exact(domain) => match candidate {
|
||||
DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
|
||||
_ => false,
|
||||
},
|
||||
DomainPattern::SubdomainsOnly(domain) => match candidate {
|
||||
DomainPattern::Any => false,
|
||||
DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
|
||||
DomainPattern::SubdomainsOnly(candidate) => {
|
||||
is_subdomain_or_equal(candidate, domain)
|
||||
}
|
||||
DomainPattern::ApexAndSubdomains(candidate) => {
|
||||
is_strict_subdomain(candidate, domain)
|
||||
}
|
||||
},
|
||||
DomainPattern::ApexAndSubdomains(domain) => match candidate {
|
||||
DomainPattern::Any => false,
|
||||
DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
|
||||
DomainPattern::SubdomainsOnly(candidate) => {
|
||||
is_subdomain_or_equal(candidate, domain)
|
||||
}
|
||||
DomainPattern::ApexAndSubdomains(candidate) => {
|
||||
is_subdomain_or_equal(candidate, domain)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_domain_pattern(pattern: &str) -> Vec<String> {
|
||||
match DomainPattern::parse(pattern) {
|
||||
DomainPattern::Any => vec![pattern.to_string()],
|
||||
DomainPattern::Exact(domain) => vec![domain],
|
||||
DomainPattern::SubdomainsOnly(domain) => {
|
||||
vec![format!("?*.{domain}")]
|
||||
}
|
||||
DomainPattern::ApexAndSubdomains(domain) => {
|
||||
vec![domain.clone(), format!("?*.{domain}")]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_domain(domain: &str) -> String {
|
||||
domain.trim_end_matches('.').to_ascii_lowercase()
|
||||
}
|
||||
|
||||
fn domain_eq(left: &str, right: &str) -> bool {
|
||||
normalize_domain(left) == normalize_domain(right)
|
||||
}
|
||||
|
||||
fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
|
||||
let child = normalize_domain(child);
|
||||
let parent = normalize_domain(parent);
|
||||
if child == parent {
|
||||
return true;
|
||||
}
|
||||
child.ends_with(&format!(".{parent}"))
|
||||
}
|
||||
|
||||
fn is_strict_subdomain(child: &str, parent: &str) -> bool {
|
||||
let child = normalize_domain(child);
|
||||
let parent = normalize_domain(parent);
|
||||
child != parent && child.ends_with(&format!(".{parent}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn method_allowed_full_allows_everything() {
|
||||
assert!(method_allowed(NetworkMode::Full, "GET"));
|
||||
assert!(method_allowed(NetworkMode::Full, "POST"));
|
||||
assert!(method_allowed(NetworkMode::Full, "CONNECT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn method_allowed_limited_allows_only_safe_methods() {
|
||||
assert!(method_allowed(NetworkMode::Limited, "GET"));
|
||||
assert!(method_allowed(NetworkMode::Limited, "HEAD"));
|
||||
assert!(method_allowed(NetworkMode::Limited, "OPTIONS"));
|
||||
assert!(!method_allowed(NetworkMode::Limited, "POST"));
|
||||
assert!(!method_allowed(NetworkMode::Limited, "CONNECT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_trailing_dots() {
|
||||
let set = compile_globset(&["Example.COM.".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("example.com"));
|
||||
assert_eq!(false, set.is_match("api.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_wildcards() {
|
||||
let set = compile_globset(&["*.Example.COM.".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("api.example.com"));
|
||||
assert_eq!(false, set.is_match("example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_apex_and_subdomains() {
|
||||
let set = compile_globset(&["**.Example.COM.".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("example.com"));
|
||||
assert_eq!(true, set.is_match("api.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_normalizes_bracketed_ipv6_literals() {
|
||||
let set = compile_globset(&["[::1]".to_string()]).unwrap();
|
||||
|
||||
assert_eq!(true, set.is_match("::1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_loopback_host_handles_localhost_variants() {
|
||||
assert!(is_loopback_host("localhost"));
|
||||
assert!(is_loopback_host("localhost."));
|
||||
assert!(is_loopback_host("LOCALHOST"));
|
||||
assert!(!is_loopback_host("notlocalhost"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_loopback_host_handles_ip_literals() {
|
||||
assert!(is_loopback_host("127.0.0.1"));
|
||||
assert!(is_loopback_host("::1"));
|
||||
assert!(!is_loopback_host("1.2.3.4"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_non_public_ip_rejects_private_and_loopback_ranges() {
|
||||
assert!(is_non_public_ip("127.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("10.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("192.168.0.1".parse().unwrap()));
|
||||
assert!(!is_non_public_ip("8.8.8.8".parse().unwrap()));
|
||||
|
||||
assert!(is_non_public_ip("::ffff:127.0.0.1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("::ffff:10.0.0.1".parse().unwrap()));
|
||||
assert!(!is_non_public_ip("::ffff:8.8.8.8".parse().unwrap()));
|
||||
|
||||
assert!(is_non_public_ip("::1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("fe80::1".parse().unwrap()));
|
||||
assert!(is_non_public_ip("fc00::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_lowercases_and_trims() {
|
||||
assert_eq!(normalize_host(" ExAmPlE.CoM "), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_port_for_host_port() {
|
||||
assert_eq!(normalize_host("example.com:1234"), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_preserves_unbracketed_ipv6() {
|
||||
assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_trailing_dot() {
|
||||
assert_eq!(normalize_host("example.com."), "example.com");
|
||||
assert_eq!(normalize_host("ExAmPlE.CoM."), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_trailing_dot_with_port() {
|
||||
assert_eq!(normalize_host("example.com.:443"), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_host_strips_brackets_for_ipv6() {
|
||||
assert_eq!(normalize_host("[::1]"), "::1");
|
||||
assert_eq!(normalize_host("[::1]:443"), "::1");
|
||||
}
|
||||
}
|
||||
202
codex-rs/network-proxy/src/proxy.rs
Normal file
202
codex-rs/network-proxy/src/proxy.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
use crate::admin;
|
||||
use crate::config;
|
||||
use crate::http_proxy;
|
||||
use crate::init;
|
||||
use crate::network_policy::NetworkPolicyDecider;
|
||||
use crate::socks5;
|
||||
use crate::state::AppState;
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use clap::Subcommand;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Clone, Parser)]
|
||||
#[command(name = "codex-network-proxy", about = "Codex network sandbox proxy")]
|
||||
pub struct Args {
|
||||
#[command(subcommand)]
|
||||
pub command: Option<Command>,
|
||||
/// Enable SOCKS5 UDP associate support (default: disabled).
|
||||
#[arg(long, default_value_t = false)]
|
||||
pub enable_socks5_udp: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Subcommand)]
|
||||
pub enum Command {
|
||||
/// Initialize the Codex network proxy directories (e.g. MITM cert paths).
|
||||
Init,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct NetworkProxyBuilder {
|
||||
state: Option<Arc<AppState>>,
|
||||
http_addr: Option<SocketAddr>,
|
||||
socks_addr: Option<SocketAddr>,
|
||||
admin_addr: Option<SocketAddr>,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
enable_socks5_udp: bool,
|
||||
}
|
||||
|
||||
impl NetworkProxyBuilder {
|
||||
#[must_use]
|
||||
pub fn state(mut self, state: Arc<AppState>) -> Self {
|
||||
self.state = Some(state);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn http_addr(mut self, addr: SocketAddr) -> Self {
|
||||
self.http_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn socks_addr(mut self, addr: SocketAddr) -> Self {
|
||||
self.socks_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn admin_addr(mut self, addr: SocketAddr) -> Self {
|
||||
self.admin_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn policy_decider<D>(mut self, decider: D) -> Self
|
||||
where
|
||||
D: NetworkPolicyDecider,
|
||||
{
|
||||
self.policy_decider = Some(Arc::new(decider));
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn policy_decider_arc(mut self, decider: Arc<dyn NetworkPolicyDecider>) -> Self {
|
||||
self.policy_decider = Some(decider);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn enable_socks5_udp(mut self, enabled: bool) -> Self {
|
||||
self.enable_socks5_udp = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self) -> Result<NetworkProxy> {
|
||||
let state = match self.state {
|
||||
Some(state) => state,
|
||||
None => Arc::new(AppState::new().await?),
|
||||
};
|
||||
let runtime = config::resolve_runtime(&state.current_cfg().await?);
|
||||
let current_cfg = state.current_cfg().await?;
|
||||
// Reapply bind clamping for caller overrides so unix-socket proxying stays loopback-only.
|
||||
let (http_addr, admin_addr) = config::clamp_bind_addrs(
|
||||
self.http_addr.unwrap_or(runtime.http_addr),
|
||||
self.admin_addr.unwrap_or(runtime.admin_addr),
|
||||
¤t_cfg.network_proxy,
|
||||
);
|
||||
|
||||
Ok(NetworkProxy {
|
||||
state,
|
||||
http_addr,
|
||||
socks_addr: self.socks_addr.unwrap_or(runtime.socks_addr),
|
||||
admin_addr,
|
||||
policy_decider: self.policy_decider,
|
||||
enable_socks5_udp: self.enable_socks5_udp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NetworkProxy {
|
||||
state: Arc<AppState>,
|
||||
http_addr: SocketAddr,
|
||||
socks_addr: SocketAddr,
|
||||
admin_addr: SocketAddr,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
enable_socks5_udp: bool,
|
||||
}
|
||||
|
||||
impl NetworkProxy {
|
||||
#[must_use]
|
||||
pub fn builder() -> NetworkProxyBuilder {
|
||||
NetworkProxyBuilder::default()
|
||||
}
|
||||
|
||||
pub async fn from_cli_args(args: Args) -> Result<Self> {
|
||||
let mut builder = Self::builder();
|
||||
builder = builder.enable_socks5_udp(args.enable_socks5_udp);
|
||||
builder.build().await
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<NetworkProxyHandle> {
|
||||
let current_cfg = self.state.current_cfg().await?;
|
||||
if !current_cfg.network_proxy.enabled {
|
||||
warn!("network_proxy.enabled is false; skipping proxy listeners");
|
||||
return Ok(NetworkProxyHandle::noop());
|
||||
}
|
||||
|
||||
if cfg!(not(target_os = "macos")) {
|
||||
warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
|
||||
}
|
||||
|
||||
let http_task = tokio::spawn(http_proxy::run_http_proxy(
|
||||
self.state.clone(),
|
||||
self.http_addr,
|
||||
self.policy_decider.clone(),
|
||||
));
|
||||
let socks_task = tokio::spawn(socks5::run_socks5(
|
||||
self.state.clone(),
|
||||
self.socks_addr,
|
||||
self.policy_decider.clone(),
|
||||
self.enable_socks5_udp,
|
||||
));
|
||||
let admin_task = tokio::spawn(admin::run_admin_api(self.state.clone(), self.admin_addr));
|
||||
|
||||
Ok(NetworkProxyHandle {
|
||||
http_task,
|
||||
socks_task,
|
||||
admin_task,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NetworkProxyHandle {
|
||||
http_task: JoinHandle<Result<()>>,
|
||||
socks_task: JoinHandle<Result<()>>,
|
||||
admin_task: JoinHandle<Result<()>>,
|
||||
}
|
||||
|
||||
impl NetworkProxyHandle {
|
||||
fn noop() -> Self {
|
||||
Self {
|
||||
http_task: tokio::spawn(async { Ok(()) }),
|
||||
socks_task: tokio::spawn(async { Ok(()) }),
|
||||
admin_task: tokio::spawn(async { Ok(()) }),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wait(self) -> Result<()> {
|
||||
self.http_task.await??;
|
||||
self.socks_task.await??;
|
||||
self.admin_task.await??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn shutdown(self) -> Result<()> {
|
||||
self.http_task.abort();
|
||||
self.socks_task.abort();
|
||||
self.admin_task.abort();
|
||||
let _ = self.http_task.await;
|
||||
let _ = self.socks_task.await;
|
||||
let _ = self.admin_task.await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_init() -> Result<()> {
|
||||
init::run_init()
|
||||
}
|
||||
54
codex-rs/network-proxy/src/responses.rs
Normal file
54
codex-rs/network-proxy/src/responses.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use rama_http::Body;
|
||||
use rama_http::Response;
|
||||
use rama_http::StatusCode;
|
||||
use serde::Serialize;
|
||||
|
||||
pub fn text_response(status: StatusCode, body: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
|
||||
}
|
||||
|
||||
pub fn json_response<T: Serialize>(value: &T) -> Response {
|
||||
let body = match serde_json::to_string(value) {
|
||||
Ok(body) => body,
|
||||
Err(_) => "{}".to_string(),
|
||||
};
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(body))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("{}")))
|
||||
}
|
||||
|
||||
pub fn blocked_header_value(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
|
||||
"denied" => "blocked-by-denylist",
|
||||
"method_not_allowed" => "blocked-by-method-policy",
|
||||
"mitm_required" => "blocked-by-mitm-required",
|
||||
_ => "blocked-by-policy",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn blocked_message(reason: &str) -> &'static str {
|
||||
match reason {
|
||||
"not_allowed" => "Codex blocked this request: domain not in allowlist.",
|
||||
"not_allowed_local" => "Codex blocked this request: local/private addresses not allowed.",
|
||||
"denied" => "Codex blocked this request: domain denied by policy.",
|
||||
"method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
|
||||
"mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
|
||||
_ => "Codex blocked this request by network policy.",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn blocked_text_response(reason: &str) -> Response {
|
||||
Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "text/plain")
|
||||
.header("x-proxy-error", blocked_header_value(reason))
|
||||
.body(Body::from(blocked_message(reason)))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
|
||||
}
|
||||
909
codex-rs/network-proxy/src/runtime.rs
Normal file
909
codex-rs/network-proxy/src/runtime.rs
Normal file
@@ -0,0 +1,909 @@
|
||||
use crate::config::Config;
|
||||
use crate::config::NetworkMode;
|
||||
use crate::mitm::MitmState;
|
||||
use crate::policy::is_loopback_host;
|
||||
use crate::policy::is_non_public_ip;
|
||||
use crate::policy::method_allowed;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::state::NetworkProxyConstraints;
|
||||
use crate::state::build_config_state;
|
||||
use crate::state::validate_policy_against_constraints;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use globset::GlobSet;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use time::OffsetDateTime;
|
||||
use tokio::net::lookup_host;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::timeout;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
const MAX_BLOCKED_EVENTS: usize = 200;
|
||||
const DNS_LOOKUP_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct BlockedRequest {
|
||||
pub host: String,
|
||||
pub reason: String,
|
||||
pub client: Option<String>,
|
||||
pub method: Option<String>,
|
||||
pub mode: Option<NetworkMode>,
|
||||
pub protocol: String,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
|
||||
impl BlockedRequest {
|
||||
pub fn new(
|
||||
host: String,
|
||||
reason: String,
|
||||
client: Option<String>,
|
||||
method: Option<String>,
|
||||
mode: Option<NetworkMode>,
|
||||
protocol: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
host,
|
||||
reason,
|
||||
client,
|
||||
method,
|
||||
mode,
|
||||
protocol,
|
||||
timestamp: unix_timestamp(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConfigState {
|
||||
pub(crate) config: Config,
|
||||
pub(crate) mtime: Option<SystemTime>,
|
||||
pub(crate) allow_set: GlobSet,
|
||||
pub(crate) deny_set: GlobSet,
|
||||
pub(crate) mitm: Option<Arc<MitmState>>,
|
||||
pub(crate) constraints: NetworkProxyConstraints,
|
||||
pub(crate) cfg_path: PathBuf,
|
||||
pub(crate) blocked: VecDeque<BlockedRequest>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
state: Arc<RwLock<ConfigState>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AppState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Avoid logging internal state (config contents, derived globsets, etc.) which can be noisy
|
||||
// and may contain sensitive paths.
|
||||
f.debug_struct("AppState").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new() -> Result<Self> {
|
||||
let cfg_state = build_config_state().await?;
|
||||
Ok(Self {
|
||||
state: Arc::new(RwLock::new(cfg_state)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn current_cfg(&self) -> Result<Config> {
|
||||
// Callers treat `AppState` as a live view of policy. We reload-on-demand so edits to
|
||||
// `config.toml` (including Codex-managed writes) take effect without a restart.
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.clone())
|
||||
}
|
||||
|
||||
pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok((
|
||||
guard.config.network_proxy.policy.allowed_domains.clone(),
|
||||
guard.config.network_proxy.policy.denied_domains.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn enabled(&self) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.enabled)
|
||||
}
|
||||
|
||||
pub async fn force_reload(&self) -> Result<()> {
|
||||
let mut guard = self.state.write().await;
|
||||
let previous_cfg = guard.config.clone();
|
||||
let blocked = guard.blocked.clone();
|
||||
match build_config_state().await {
|
||||
Ok(mut new_state) => {
|
||||
// Policy changes are operationally sensitive; logging diffs makes changes traceable
|
||||
// without needing to dump full config blobs (which can include unrelated settings).
|
||||
log_policy_changes(&previous_cfg, &new_state.config);
|
||||
new_state.blocked = blocked;
|
||||
*guard = new_state;
|
||||
let path = guard.cfg_path.display();
|
||||
info!("reloaded config from {path}");
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let path = guard.cfg_path.display();
|
||||
warn!("failed to reload config from {path}: {err}; keeping previous config");
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn host_blocked(&self, host: &str, port: u16) -> Result<(bool, String)> {
|
||||
self.reload_if_needed().await?;
|
||||
let (deny_set, allow_set, allow_local_binding, allowed_domains_empty, allowed_domains) = {
|
||||
let guard = self.state.read().await;
|
||||
(
|
||||
guard.deny_set.clone(),
|
||||
guard.allow_set.clone(),
|
||||
guard.config.network_proxy.policy.allow_local_binding,
|
||||
guard.config.network_proxy.policy.allowed_domains.is_empty(),
|
||||
guard.config.network_proxy.policy.allowed_domains.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
// Decision order matters:
|
||||
// 1) explicit deny always wins
|
||||
// 2) local/private networking is opt-in (defense-in-depth)
|
||||
// 3) allowlist is enforced when configured
|
||||
if deny_set.is_match(host) {
|
||||
return Ok((true, "denied".to_string()));
|
||||
}
|
||||
|
||||
let is_allowlisted = allow_set.is_match(host);
|
||||
if !allow_local_binding {
|
||||
// If the intent is "prevent access to local/internal networks", we must not rely solely
|
||||
// on string checks like `localhost` / `127.0.0.1`. Attackers can use DNS rebinding or
|
||||
// public suffix services that map hostnames onto private IPs.
|
||||
//
|
||||
// We therefore do a best-effort DNS + IP classification check before allowing the
|
||||
// request. Explicit local/loopback literals are allowed only when explicitly
|
||||
// allowlisted; hostnames that resolve to local/private IPs are blocked even if
|
||||
// allowlisted.
|
||||
let local_literal = {
|
||||
let host = host.trim();
|
||||
let host = host.split_once('%').map(|(ip, _)| ip).unwrap_or(host);
|
||||
if is_loopback_host(host) {
|
||||
true
|
||||
} else if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
is_non_public_ip(ip)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if local_literal {
|
||||
if !is_explicit_local_allowlisted(&allowed_domains, host) {
|
||||
return Ok((true, "not_allowed_local".to_string()));
|
||||
}
|
||||
} else if host_resolves_to_non_public_ip(host, port).await? {
|
||||
return Ok((true, "not_allowed_local".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
if allowed_domains_empty {
|
||||
return Ok((true, "not_allowed".to_string()));
|
||||
}
|
||||
|
||||
if !is_allowlisted {
|
||||
return Ok((true, "not_allowed".to_string()));
|
||||
}
|
||||
Ok((false, String::new()))
|
||||
}
|
||||
|
||||
pub async fn record_blocked(&self, entry: BlockedRequest) -> Result<()> {
|
||||
self.reload_if_needed().await?;
|
||||
let mut guard = self.state.write().await;
|
||||
guard.blocked.push_back(entry);
|
||||
while guard.blocked.len() > MAX_BLOCKED_EVENTS {
|
||||
guard.blocked.pop_front();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn drain_blocked(&self) -> Result<Vec<BlockedRequest>> {
|
||||
self.reload_if_needed().await?;
|
||||
let mut guard = self.state.write().await;
|
||||
let blocked = std::mem::take(&mut guard.blocked);
|
||||
Ok(blocked.into_iter().collect())
|
||||
}
|
||||
|
||||
pub async fn is_unix_socket_allowed(&self, path: &str) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
if cfg!(not(target_os = "macos")) {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// We only support absolute unix socket paths (a relative path would be ambiguous with
|
||||
// respect to the proxy process's CWD and can lead to confusing allowlist behavior).
|
||||
if !Path::new(path).is_absolute() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let guard = self.state.read().await;
|
||||
let requested_canonical = std::fs::canonicalize(path).ok();
|
||||
for allowed in &guard.config.network_proxy.policy.allow_unix_sockets {
|
||||
if allowed == path {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Best-effort canonicalization to reduce surprises with symlinks.
|
||||
// If canonicalization fails (e.g., socket not created yet), fall back to raw comparison.
|
||||
let Some(requested_canonical) = &requested_canonical else {
|
||||
continue;
|
||||
};
|
||||
if let Ok(allowed_canonical) = std::fs::canonicalize(allowed)
|
||||
&& &allowed_canonical == requested_canonical
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(method_allowed(guard.config.network_proxy.mode, method))
|
||||
}
|
||||
|
||||
pub async fn allow_upstream_proxy(&self) -> Result<bool> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.allow_upstream_proxy)
|
||||
}
|
||||
|
||||
pub async fn network_mode(&self) -> Result<NetworkMode> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.config.network_proxy.mode)
|
||||
}
|
||||
|
||||
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
|
||||
self.reload_if_needed().await?;
|
||||
let mut guard = self.state.write().await;
|
||||
let mut candidate = guard.config.clone();
|
||||
candidate.network_proxy.mode = mode;
|
||||
validate_policy_against_constraints(&candidate, &guard.constraints)
|
||||
.context("network_proxy.mode constrained by managed config")?;
|
||||
guard.config.network_proxy.mode = mode;
|
||||
info!("updated network mode to {mode:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn mitm_state(&self) -> Result<Option<Arc<MitmState>>> {
|
||||
self.reload_if_needed().await?;
|
||||
let guard = self.state.read().await;
|
||||
Ok(guard.mitm.clone())
|
||||
}
|
||||
|
||||
async fn reload_if_needed(&self) -> Result<()> {
|
||||
let needs_reload = {
|
||||
let guard = self.state.read().await;
|
||||
if !guard.cfg_path.exists() {
|
||||
// If the config file is missing, only reload when it *used to* exist (mtime set).
|
||||
// This avoids forcing a reload on every request when running with the default config.
|
||||
guard.mtime.is_some()
|
||||
} else {
|
||||
let metadata = std::fs::metadata(&guard.cfg_path).ok();
|
||||
match (metadata.and_then(|m| m.modified().ok()), guard.mtime) {
|
||||
(Some(new_mtime), Some(old_mtime)) => new_mtime > old_mtime,
|
||||
(Some(_), None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !needs_reload {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.force_reload().await
|
||||
}
|
||||
}
|
||||
|
||||
async fn host_resolves_to_non_public_ip(host: &str, port: u16) -> Result<bool> {
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return Ok(is_non_public_ip(ip));
|
||||
}
|
||||
|
||||
// If DNS lookup fails, default to "not local/private" rather than blocking. In practice, the
|
||||
// subsequent connect attempt will fail anyway, and blocking on transient resolver issues would
|
||||
// make the proxy fragile. The allowlist/denylist remains the primary control plane.
|
||||
let addrs = match timeout(DNS_LOOKUP_TIMEOUT, lookup_host((host, port))).await {
|
||||
Ok(Ok(addrs)) => addrs,
|
||||
Ok(Err(_)) | Err(_) => return Ok(false),
|
||||
};
|
||||
|
||||
for addr in addrs {
|
||||
if is_non_public_ip(addr.ip()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
fn log_policy_changes(previous: &Config, next: &Config) {
|
||||
log_domain_list_changes(
|
||||
"allowlist",
|
||||
&previous.network_proxy.policy.allowed_domains,
|
||||
&next.network_proxy.policy.allowed_domains,
|
||||
);
|
||||
log_domain_list_changes(
|
||||
"denylist",
|
||||
&previous.network_proxy.policy.denied_domains,
|
||||
&next.network_proxy.policy.denied_domains,
|
||||
);
|
||||
}
|
||||
|
||||
fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]) {
|
||||
let previous_set: HashSet<String> = previous
|
||||
.iter()
|
||||
.map(|entry| entry.to_ascii_lowercase())
|
||||
.collect();
|
||||
let next_set: HashSet<String> = next
|
||||
.iter()
|
||||
.map(|entry| entry.to_ascii_lowercase())
|
||||
.collect();
|
||||
|
||||
let mut seen_next = HashSet::new();
|
||||
for entry in next {
|
||||
let key = entry.to_ascii_lowercase();
|
||||
if seen_next.insert(key.clone()) && !previous_set.contains(&key) {
|
||||
info!("config entry added to {list_name}: {entry}");
|
||||
}
|
||||
}
|
||||
|
||||
let mut seen_previous = HashSet::new();
|
||||
for entry in previous {
|
||||
let key = entry.to_ascii_lowercase();
|
||||
if seen_previous.insert(key.clone()) && !next_set.contains(&key) {
|
||||
info!("config entry removed from {list_name}: {entry}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_explicit_local_allowlisted(allowed_domains: &[String], host: &str) -> bool {
|
||||
let normalized_host = normalize_host(host);
|
||||
allowed_domains.iter().any(|pattern| {
|
||||
let pattern = pattern.trim();
|
||||
if pattern == "*" || pattern.starts_with("*.") || pattern.starts_with("**.") {
|
||||
return false;
|
||||
}
|
||||
if pattern.contains('*') || pattern.contains('?') {
|
||||
return false;
|
||||
}
|
||||
normalize_host(pattern) == normalized_host
|
||||
})
|
||||
}
|
||||
|
||||
fn unix_timestamp() -> i64 {
|
||||
OffsetDateTime::now_utc().unix_timestamp()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn app_state_for_policy(policy: crate::config::NetworkPolicy) -> AppState {
|
||||
let config = Config {
|
||||
network_proxy: crate::config::NetworkProxyConfig {
|
||||
enabled: true,
|
||||
mode: NetworkMode::Full,
|
||||
policy,
|
||||
..crate::config::NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
let allow_set =
|
||||
crate::policy::compile_globset(&config.network_proxy.policy.allowed_domains).unwrap();
|
||||
let deny_set =
|
||||
crate::policy::compile_globset(&config.network_proxy.policy.denied_domains).unwrap();
|
||||
|
||||
let state = ConfigState {
|
||||
config,
|
||||
mtime: None,
|
||||
allow_set,
|
||||
deny_set,
|
||||
mitm: None,
|
||||
constraints: NetworkProxyConstraints::default(),
|
||||
cfg_path: PathBuf::from("/nonexistent/config.toml"),
|
||||
blocked: VecDeque::new(),
|
||||
};
|
||||
|
||||
AppState {
|
||||
state: Arc::new(RwLock::new(state)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::config::NetworkPolicy;
|
||||
use crate::config::NetworkProxyConfig;
|
||||
use crate::policy::compile_globset;
|
||||
use crate::state::NetworkProxyConstraints;
|
||||
use crate::state::validate_policy_against_constraints;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_denied_wins_over_allowed() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
denied_domains: vec!["example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("example.com", 80).await.unwrap(),
|
||||
(true, "denied".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_requires_allowlist_match() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("example.com", 80).await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
assert_eq!(
|
||||
// Use a public IP literal to avoid relying on ambient DNS behavior (some networks
|
||||
// resolve unknown hostnames to private IPs, which would trigger `not_allowed_local`).
|
||||
state.host_blocked("8.8.8.8", 80).await.unwrap(),
|
||||
(true, "not_allowed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_subdomain_wildcards_exclude_apex() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["*.openai.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("api.openai.com", 80).await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("openai.com", 80).await.unwrap(),
|
||||
(true, "not_allowed".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_loopback_when_local_binding_disabled() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("127.0.0.1", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
state.host_blocked("localhost", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_loopback_when_allowlist_is_wildcard() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["*".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("127.0.0.1", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_private_ip_literal_when_allowlist_is_wildcard() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["*".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("10.0.0.1", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_allows_loopback_when_explicitly_allowlisted_and_local_binding_disabled() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["localhost".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("localhost", 80).await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_allows_private_ip_literal_when_explicitly_allowlisted() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["10.0.0.1".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("10.0.0.1", 80).await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_scoped_ipv6_literal_when_not_allowlisted() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_allows_scoped_ipv6_literal_when_explicitly_allowlisted() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["fe80::1%lo0".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
|
||||
(false, String::new())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_private_ip_literals_when_local_binding_disabled() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("10.0.0.1", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_blocked_rejects_loopback_when_allowlist_empty() {
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec![],
|
||||
allow_local_binding: false,
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
state.host_blocked("127.0.0.1", 80).await.unwrap(),
|
||||
(true, "not_allowed_local".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_widening_allowed_domains() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
allowed_domains: Some(vec!["example.com".to_string()]),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string(), "evil.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_widening_mode() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
mode: Some(NetworkMode::Limited),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
mode: NetworkMode::Full,
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_allows_narrowing_wildcard_allowlist() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
allowed_domains: Some(vec!["*.example.com".to_string()]),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
allowed_domains: vec!["api.example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_rejects_widening_wildcard_allowlist() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
allowed_domains: Some(vec!["*.example.com".to_string()]),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
allowed_domains: vec!["**.example.com".to_string()],
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_requires_managed_denied_domains_entries() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
denied_domains: Some(vec!["evil.com".to_string()]),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
denied_domains: vec![],
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_enabling_when_managed_disabled() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
enabled: Some(false),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_allow_local_binding_when_managed_disabled() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
allow_local_binding: Some(false),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
policy: NetworkPolicy {
|
||||
allow_local_binding: true,
|
||||
..NetworkPolicy::default()
|
||||
},
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_disallows_non_loopback_admin_without_managed_opt_in() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
dangerously_allow_non_loopback_admin: Some(false),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
dangerously_allow_non_loopback_admin: true,
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_policy_against_constraints_allows_non_loopback_admin_with_managed_opt_in() {
|
||||
let constraints = NetworkProxyConstraints {
|
||||
dangerously_allow_non_loopback_admin: Some(true),
|
||||
..NetworkProxyConstraints::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
network_proxy: NetworkProxyConfig {
|
||||
enabled: true,
|
||||
dangerously_allow_non_loopback_admin: true,
|
||||
..NetworkProxyConfig::default()
|
||||
},
|
||||
};
|
||||
|
||||
assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_is_case_insensitive() {
|
||||
let patterns = vec!["ExAmPle.CoM".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("example.com"));
|
||||
assert!(set.is_match("EXAMPLE.COM"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_excludes_apex_for_subdomain_patterns() {
|
||||
let patterns = vec!["*.openai.com".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("api.openai.com"));
|
||||
assert!(!set.is_match("openai.com"));
|
||||
assert!(!set.is_match("evilopenai.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_includes_apex_for_double_wildcard_patterns() {
|
||||
let patterns = vec!["**.openai.com".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("openai.com"));
|
||||
assert!(set.is_match("api.openai.com"));
|
||||
assert!(!set.is_match("evilopenai.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_matches_all_with_star() {
|
||||
let patterns = vec!["*".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("openai.com"));
|
||||
assert!(set.is_match("api.openai.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_dedupes_patterns_without_changing_behavior() {
|
||||
let patterns = vec!["example.com".to_string(), "example.com".to_string()];
|
||||
let set = compile_globset(&patterns).unwrap();
|
||||
assert!(set.is_match("example.com"));
|
||||
assert!(set.is_match("EXAMPLE.COM"));
|
||||
assert!(!set.is_match("not-example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_globset_rejects_invalid_patterns() {
|
||||
let patterns = vec!["[".to_string()];
|
||||
assert!(compile_globset(&patterns).is_err());
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
#[tokio::test]
|
||||
async fn unix_socket_allowlist_is_respected_on_macos() {
|
||||
let socket_path = "/tmp/example.sock".to_string();
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_unix_sockets: vec![socket_path.clone()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert!(state.is_unix_socket_allowed(&socket_path).await.unwrap());
|
||||
assert!(
|
||||
!state
|
||||
.is_unix_socket_allowed("/tmp/not-allowed.sock")
|
||||
.await
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
#[tokio::test]
|
||||
async fn unix_socket_allowlist_resolves_symlinks() {
|
||||
use std::os::unix::fs::symlink;
|
||||
|
||||
let unique = OffsetDateTime::now_utc().unix_timestamp_nanos();
|
||||
let dir = std::env::temp_dir().join(format!("codex-network-proxy-test-{unique}"));
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
|
||||
let real = dir.join("real.sock");
|
||||
let link = dir.join("link.sock");
|
||||
|
||||
// The allowlist mechanism is path-based; for test purposes we don't need an actual unix
|
||||
// domain socket. Any filesystem entry works for canonicalization.
|
||||
std::fs::write(&real, b"not a socket").unwrap();
|
||||
symlink(&real, &link).unwrap();
|
||||
|
||||
let real_s = real.to_str().unwrap().to_string();
|
||||
let link_s = link.to_str().unwrap().to_string();
|
||||
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_unix_sockets: vec![real_s],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert!(state.is_unix_socket_allowed(&link_s).await.unwrap());
|
||||
|
||||
let _ = std::fs::remove_file(&link);
|
||||
let _ = std::fs::remove_file(&real);
|
||||
let _ = std::fs::remove_dir(&dir);
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
#[tokio::test]
|
||||
async fn unix_socket_allowlist_is_rejected_on_non_macos() {
|
||||
let socket_path = "/tmp/example.sock".to_string();
|
||||
let state = app_state_for_policy(NetworkPolicy {
|
||||
allowed_domains: vec!["example.com".to_string()],
|
||||
allow_unix_sockets: vec![socket_path.clone()],
|
||||
..NetworkPolicy::default()
|
||||
});
|
||||
|
||||
assert!(!state.is_unix_socket_allowed(&socket_path).await.unwrap());
|
||||
}
|
||||
}
|
||||
302
codex-rs/network-proxy/src/socks5.rs
Normal file
302
codex-rs/network-proxy/src/socks5.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use crate::config::NetworkMode;
|
||||
use crate::network_policy::NetworkDecision;
|
||||
use crate::network_policy::NetworkPolicyDecider;
|
||||
use crate::network_policy::NetworkPolicyRequest;
|
||||
use crate::network_policy::NetworkProtocol;
|
||||
use crate::network_policy::evaluate_host_policy;
|
||||
use crate::policy::normalize_host;
|
||||
use crate::state::AppState;
|
||||
use crate::state::BlockedRequest;
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use rama_core::Layer;
|
||||
use rama_core::Service;
|
||||
use rama_core::extensions::ExtensionsRef;
|
||||
use rama_core::layer::AddInputExtensionLayer;
|
||||
use rama_core::service::service_fn;
|
||||
use rama_net::stream::SocketInfo;
|
||||
use rama_socks5::Socks5Acceptor;
|
||||
use rama_socks5::server::DefaultConnector;
|
||||
use rama_socks5::server::DefaultUdpRelay;
|
||||
use rama_socks5::server::udp::RelayRequest;
|
||||
use rama_socks5::server::udp::RelayResponse;
|
||||
use rama_tcp::client::Request as TcpRequest;
|
||||
use rama_tcp::client::service::TcpConnector;
|
||||
use rama_tcp::server::TcpListener;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub async fn run_socks5(
|
||||
state: Arc<AppState>,
|
||||
addr: SocketAddr,
|
||||
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
|
||||
enable_socks5_udp: bool,
|
||||
) -> Result<()> {
|
||||
let listener = TcpListener::build()
|
||||
.bind(addr)
|
||||
.await
|
||||
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
|
||||
.map_err(rama_core::error::OpaqueError::from)
|
||||
.map_err(anyhow::Error::from)
|
||||
.with_context(|| format!("bind SOCKS5 proxy: {addr}"))?;
|
||||
|
||||
info!("SOCKS5 proxy listening on {addr}");
|
||||
|
||||
match state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
info!("SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5");
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
warn!("failed to read network mode: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
let tcp_connector = TcpConnector::default();
|
||||
let policy_tcp_connector = service_fn({
|
||||
let policy_decider = policy_decider.clone();
|
||||
move |req: TcpRequest| {
|
||||
let tcp_connector = tcp_connector.clone();
|
||||
let policy_decider = policy_decider.clone();
|
||||
async move {
|
||||
let app_state = req
|
||||
.extensions()
|
||||
.get::<Arc<AppState>>()
|
||||
.cloned()
|
||||
.ok_or_else(|| io::Error::other("missing state"))?;
|
||||
|
||||
let host = normalize_host(&req.authority.host.to_string());
|
||||
let port = req.authority.port;
|
||||
let client = req
|
||||
.extensions()
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
match app_state.enabled().await {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"proxy_disabled".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("SOCKS blocked; proxy disabled (client={client}, host={host})");
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
"proxy disabled",
|
||||
)
|
||||
.into());
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to read enabled state: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
}
|
||||
match app_state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
Some(NetworkMode::Limited),
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
"SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
|
||||
);
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Socks5Tcp,
|
||||
host.clone(),
|
||||
port,
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
|
||||
Ok(NetworkDecision::Deny { reason }) => {
|
||||
let _ = app_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!("SOCKS blocked (client={client}, host={host}, reason={reason})");
|
||||
return Err(
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, "blocked").into()
|
||||
);
|
||||
}
|
||||
Ok(NetworkDecision::Allow) => {
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
info!("SOCKS allowed (client={client}, host={host}, port={port})");
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate host: {err}");
|
||||
return Err(io::Error::other("proxy error").into());
|
||||
}
|
||||
}
|
||||
|
||||
tcp_connector.serve(req).await
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector);
|
||||
let base = Socks5Acceptor::new().with_connector(socks_connector);
|
||||
|
||||
if enable_socks5_udp {
|
||||
let udp_state = state.clone();
|
||||
let udp_decider = policy_decider.clone();
|
||||
let udp_relay = DefaultUdpRelay::default().with_async_inspector(service_fn(
|
||||
move |request: RelayRequest| {
|
||||
let udp_state = udp_state.clone();
|
||||
let udp_decider = udp_decider.clone();
|
||||
async move {
|
||||
let RelayRequest {
|
||||
server_address,
|
||||
payload,
|
||||
extensions,
|
||||
..
|
||||
} = request;
|
||||
|
||||
let host = normalize_host(&server_address.ip_addr.to_string());
|
||||
let port = server_address.port;
|
||||
let client = extensions
|
||||
.get::<SocketInfo>()
|
||||
.map(|info| info.peer_addr().to_string());
|
||||
match udp_state.enabled().await {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
let _ = udp_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"proxy_disabled".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5-udp".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
"SOCKS UDP blocked; proxy disabled (client={client}, host={host})"
|
||||
);
|
||||
return Ok(RelayResponse {
|
||||
maybe_payload: None,
|
||||
extensions,
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to read enabled state: {err}");
|
||||
return Err(io::Error::other("proxy error"));
|
||||
}
|
||||
}
|
||||
|
||||
match udp_state.network_mode().await {
|
||||
Ok(NetworkMode::Limited) => {
|
||||
let _ = udp_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
"method_not_allowed".to_string(),
|
||||
client.clone(),
|
||||
None,
|
||||
Some(NetworkMode::Limited),
|
||||
"socks5-udp".to_string(),
|
||||
))
|
||||
.await;
|
||||
return Ok(RelayResponse {
|
||||
maybe_payload: None,
|
||||
extensions,
|
||||
});
|
||||
}
|
||||
Ok(NetworkMode::Full) => {}
|
||||
Err(err) => {
|
||||
error!("failed to evaluate method policy: {err}");
|
||||
return Err(io::Error::other("proxy error"));
|
||||
}
|
||||
}
|
||||
|
||||
let request = NetworkPolicyRequest::new(
|
||||
NetworkProtocol::Socks5Udp,
|
||||
host.clone(),
|
||||
port,
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
match evaluate_host_policy(&udp_state, udp_decider.as_ref(), &request).await {
|
||||
Ok(NetworkDecision::Deny { reason }) => {
|
||||
let _ = udp_state
|
||||
.record_blocked(BlockedRequest::new(
|
||||
host.clone(),
|
||||
reason.clone(),
|
||||
client.clone(),
|
||||
None,
|
||||
None,
|
||||
"socks5-udp".to_string(),
|
||||
))
|
||||
.await;
|
||||
let client = client.as_deref().unwrap_or_default();
|
||||
warn!(
|
||||
"SOCKS UDP blocked (client={client}, host={host}, reason={reason})"
|
||||
);
|
||||
Ok(RelayResponse {
|
||||
maybe_payload: None,
|
||||
extensions,
|
||||
})
|
||||
}
|
||||
Ok(NetworkDecision::Allow) => Ok(RelayResponse {
|
||||
maybe_payload: Some(payload),
|
||||
extensions,
|
||||
}),
|
||||
Err(err) => {
|
||||
error!("failed to evaluate UDP host: {err}");
|
||||
Err(io::Error::other("proxy error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
));
|
||||
let socks_acceptor = base.with_udp_associator(udp_relay);
|
||||
listener
|
||||
.serve(AddInputExtensionLayer::new(state).into_layer(socks_acceptor))
|
||||
.await;
|
||||
} else {
|
||||
listener
|
||||
.serve(AddInputExtensionLayer::new(state).into_layer(base))
|
||||
.await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
425
codex-rs/network-proxy/src/state.rs
Normal file
425
codex-rs/network-proxy/src/state.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
use crate::config::Config;
|
||||
use crate::config::MitmConfig;
|
||||
use crate::config::NetworkMode;
|
||||
use crate::mitm::MitmState;
|
||||
use crate::policy::DomainPattern;
|
||||
use crate::policy::compile_globset;
|
||||
use crate::runtime::ConfigState;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_app_server_protocol::ConfigLayerSource;
|
||||
use codex_core::config::CONFIG_TOML_FILE;
|
||||
use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config::Constrained;
|
||||
use codex_core::config::ConstraintError;
|
||||
use codex_core::config_loader::RequirementSource;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use crate::runtime::AppState;
|
||||
pub use crate::runtime::BlockedRequest;
|
||||
#[cfg(test)]
|
||||
pub(crate) use crate::runtime::app_state_for_policy;
|
||||
|
||||
pub(crate) async fn build_config_state() -> Result<ConfigState> {
|
||||
// Load config through `codex-core` so we inherit the same layer ordering and semantics as the
|
||||
// rest of Codex (system/managed layers, user layers, session flags, etc.).
|
||||
let codex_cfg = ConfigBuilder::default()
|
||||
.build()
|
||||
.await
|
||||
.context("failed to load Codex config")?;
|
||||
|
||||
let cfg_path = codex_cfg.codex_home.join(CONFIG_TOML_FILE);
|
||||
|
||||
// Deserialize from the merged effective config, rather than parsing config.toml ourselves.
|
||||
// This avoids a second parser/merger implementation (and the drift that comes with it).
|
||||
let merged_toml = codex_cfg.config_layer_stack.effective_config();
|
||||
let mut config: Config = merged_toml
|
||||
.try_into()
|
||||
.context("failed to deserialize network proxy config")?;
|
||||
|
||||
// Security boundary: user-controlled layers must not be able to widen restrictions set by
|
||||
// trusted/managed layers (e.g., MDM). Enforce this before building runtime state.
|
||||
let constraints = enforce_trusted_constraints(&codex_cfg.config_layer_stack, &config)?;
|
||||
|
||||
// Permit relative MITM paths for ergonomics; resolve them relative to CODEX_HOME so the
|
||||
// proxy can be configured from multiple config locations without changing cert paths.
|
||||
resolve_mitm_paths(&mut config, &codex_cfg.codex_home);
|
||||
let mtime = cfg_path.metadata().and_then(|m| m.modified()).ok();
|
||||
let deny_set = compile_globset(&config.network_proxy.policy.denied_domains)?;
|
||||
let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains)?;
|
||||
let mitm = if config.network_proxy.mitm.enabled {
|
||||
build_mitm_state(
|
||||
&config.network_proxy.mitm,
|
||||
config.network_proxy.allow_upstream_proxy,
|
||||
)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(ConfigState {
|
||||
config,
|
||||
mtime,
|
||||
allow_set,
|
||||
deny_set,
|
||||
mitm,
|
||||
constraints,
|
||||
cfg_path,
|
||||
blocked: std::collections::VecDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_mitm_paths(config: &mut Config, codex_home: &Path) {
|
||||
let base = codex_home;
|
||||
if config.network_proxy.mitm.ca_cert_path.is_relative() {
|
||||
config.network_proxy.mitm.ca_cert_path = base.join(&config.network_proxy.mitm.ca_cert_path);
|
||||
}
|
||||
if config.network_proxy.mitm.ca_key_path.is_relative() {
|
||||
config.network_proxy.mitm.ca_key_path = base.join(&config.network_proxy.mitm.ca_key_path);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_mitm_state(
|
||||
config: &MitmConfig,
|
||||
allow_upstream_proxy: bool,
|
||||
) -> Result<Option<Arc<MitmState>>> {
|
||||
Ok(Some(Arc::new(MitmState::new(
|
||||
config,
|
||||
allow_upstream_proxy,
|
||||
)?)))
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct PartialConfig {
|
||||
#[serde(default)]
|
||||
network_proxy: PartialNetworkProxyConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct PartialNetworkProxyConfig {
|
||||
enabled: Option<bool>,
|
||||
mode: Option<NetworkMode>,
|
||||
allow_upstream_proxy: Option<bool>,
|
||||
dangerously_allow_non_loopback_proxy: Option<bool>,
|
||||
dangerously_allow_non_loopback_admin: Option<bool>,
|
||||
#[serde(default)]
|
||||
policy: PartialNetworkPolicy,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
struct PartialNetworkPolicy {
|
||||
#[serde(default)]
|
||||
allowed_domains: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
denied_domains: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
allow_unix_sockets: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
allow_local_binding: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub(crate) struct NetworkProxyConstraints {
|
||||
pub(crate) enabled: Option<bool>,
|
||||
pub(crate) mode: Option<NetworkMode>,
|
||||
pub(crate) allow_upstream_proxy: Option<bool>,
|
||||
pub(crate) dangerously_allow_non_loopback_proxy: Option<bool>,
|
||||
pub(crate) dangerously_allow_non_loopback_admin: Option<bool>,
|
||||
pub(crate) allowed_domains: Option<Vec<String>>,
|
||||
pub(crate) denied_domains: Option<Vec<String>>,
|
||||
pub(crate) allow_unix_sockets: Option<Vec<String>>,
|
||||
pub(crate) allow_local_binding: Option<bool>,
|
||||
}
|
||||
|
||||
fn enforce_trusted_constraints(
|
||||
layers: &codex_core::config_loader::ConfigLayerStack,
|
||||
config: &Config,
|
||||
) -> Result<NetworkProxyConstraints> {
|
||||
let constraints = network_proxy_constraints_from_trusted_layers(layers)?;
|
||||
validate_policy_against_constraints(config, &constraints)
|
||||
.context("network proxy constraints")?;
|
||||
Ok(constraints)
|
||||
}
|
||||
|
||||
fn network_proxy_constraints_from_trusted_layers(
|
||||
layers: &codex_core::config_loader::ConfigLayerStack,
|
||||
) -> Result<NetworkProxyConstraints> {
|
||||
let mut constraints = NetworkProxyConstraints::default();
|
||||
for layer in layers
|
||||
.get_layers(codex_core::config_loader::ConfigLayerStackOrdering::LowestPrecedenceFirst)
|
||||
{
|
||||
// Only trusted layers contribute constraints. User-controlled layers can narrow policy but
|
||||
// must never widen beyond what managed config allows.
|
||||
if is_user_controlled_layer(&layer.name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let partial: PartialConfig = layer
|
||||
.config
|
||||
.clone()
|
||||
.try_into()
|
||||
.context("failed to deserialize trusted config layer")?;
|
||||
|
||||
if let Some(enabled) = partial.network_proxy.enabled {
|
||||
constraints.enabled = Some(enabled);
|
||||
}
|
||||
if let Some(mode) = partial.network_proxy.mode {
|
||||
constraints.mode = Some(mode);
|
||||
}
|
||||
if let Some(allow_upstream_proxy) = partial.network_proxy.allow_upstream_proxy {
|
||||
constraints.allow_upstream_proxy = Some(allow_upstream_proxy);
|
||||
}
|
||||
if let Some(dangerously_allow_non_loopback_proxy) =
|
||||
partial.network_proxy.dangerously_allow_non_loopback_proxy
|
||||
{
|
||||
constraints.dangerously_allow_non_loopback_proxy =
|
||||
Some(dangerously_allow_non_loopback_proxy);
|
||||
}
|
||||
if let Some(dangerously_allow_non_loopback_admin) =
|
||||
partial.network_proxy.dangerously_allow_non_loopback_admin
|
||||
{
|
||||
constraints.dangerously_allow_non_loopback_admin =
|
||||
Some(dangerously_allow_non_loopback_admin);
|
||||
}
|
||||
|
||||
if let Some(allowed_domains) = partial.network_proxy.policy.allowed_domains {
|
||||
constraints.allowed_domains = Some(allowed_domains);
|
||||
}
|
||||
if let Some(denied_domains) = partial.network_proxy.policy.denied_domains {
|
||||
constraints.denied_domains = Some(denied_domains);
|
||||
}
|
||||
if let Some(allow_unix_sockets) = partial.network_proxy.policy.allow_unix_sockets {
|
||||
constraints.allow_unix_sockets = Some(allow_unix_sockets);
|
||||
}
|
||||
if let Some(allow_local_binding) = partial.network_proxy.policy.allow_local_binding {
|
||||
constraints.allow_local_binding = Some(allow_local_binding);
|
||||
}
|
||||
}
|
||||
Ok(constraints)
|
||||
}
|
||||
|
||||
fn is_user_controlled_layer(layer: &ConfigLayerSource) -> bool {
|
||||
matches!(
|
||||
layer,
|
||||
ConfigLayerSource::User { .. }
|
||||
| ConfigLayerSource::Project { .. }
|
||||
| ConfigLayerSource::SessionFlags
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn validate_policy_against_constraints(
|
||||
config: &Config,
|
||||
constraints: &NetworkProxyConstraints,
|
||||
) -> std::result::Result<(), ConstraintError> {
|
||||
fn invalid_value(
|
||||
field_name: &'static str,
|
||||
candidate: impl Into<String>,
|
||||
allowed: impl Into<String>,
|
||||
) -> ConstraintError {
|
||||
ConstraintError::InvalidValue {
|
||||
field_name,
|
||||
candidate: candidate.into(),
|
||||
allowed: allowed.into(),
|
||||
requirement_source: RequirementSource::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
let enabled = config.network_proxy.enabled;
|
||||
if let Some(max_enabled) = constraints.enabled {
|
||||
let _ = Constrained::new(enabled, move |candidate| {
|
||||
if *candidate && !max_enabled {
|
||||
Err(invalid_value(
|
||||
"network_proxy.enabled",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
if let Some(max_mode) = constraints.mode {
|
||||
let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
|
||||
if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
|
||||
Err(invalid_value(
|
||||
"network_proxy.mode",
|
||||
format!("{candidate:?}"),
|
||||
format!("{max_mode:?} or more restrictive"),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
let allow_upstream_proxy = constraints.allow_upstream_proxy;
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.allow_upstream_proxy,
|
||||
move |candidate| match allow_upstream_proxy {
|
||||
Some(true) | None => Ok(()),
|
||||
Some(false) => {
|
||||
if *candidate {
|
||||
Err(invalid_value(
|
||||
"network_proxy.allow_upstream_proxy",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
let allow_non_loopback_admin = constraints.dangerously_allow_non_loopback_admin;
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.dangerously_allow_non_loopback_admin,
|
||||
move |candidate| match allow_non_loopback_admin {
|
||||
Some(true) | None => Ok(()),
|
||||
Some(false) => {
|
||||
if *candidate {
|
||||
Err(invalid_value(
|
||||
"network_proxy.dangerously_allow_non_loopback_admin",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
let allow_non_loopback_proxy = constraints.dangerously_allow_non_loopback_proxy;
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.dangerously_allow_non_loopback_proxy,
|
||||
move |candidate| match allow_non_loopback_proxy {
|
||||
Some(true) | None => Ok(()),
|
||||
Some(false) => {
|
||||
if *candidate {
|
||||
Err(invalid_value(
|
||||
"network_proxy.dangerously_allow_non_loopback_proxy",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
if let Some(allow_local_binding) = constraints.allow_local_binding {
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.allow_local_binding,
|
||||
move |candidate| {
|
||||
if *candidate && !allow_local_binding {
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.allow_local_binding",
|
||||
"true",
|
||||
"false (disabled by managed config)",
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(allowed_domains) = &constraints.allowed_domains {
|
||||
let managed_patterns: Vec<DomainPattern> = allowed_domains
|
||||
.iter()
|
||||
.map(|entry| DomainPattern::parse(entry))
|
||||
.collect();
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.allowed_domains.clone(),
|
||||
move |candidate| {
|
||||
let mut invalid = Vec::new();
|
||||
for entry in candidate {
|
||||
let candidate_pattern = DomainPattern::parse(entry);
|
||||
if !managed_patterns
|
||||
.iter()
|
||||
.any(|managed| managed.allows(&candidate_pattern))
|
||||
{
|
||||
invalid.push(entry.clone());
|
||||
}
|
||||
}
|
||||
if invalid.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.allowed_domains",
|
||||
format!("{invalid:?}"),
|
||||
"subset of managed allowed_domains",
|
||||
))
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(denied_domains) = &constraints.denied_domains {
|
||||
let required_set: HashSet<String> = denied_domains
|
||||
.iter()
|
||||
.map(|s| s.to_ascii_lowercase())
|
||||
.collect();
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.denied_domains.clone(),
|
||||
move |candidate| {
|
||||
let candidate_set: HashSet<String> =
|
||||
candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
|
||||
let missing: Vec<String> = required_set
|
||||
.iter()
|
||||
.filter(|entry| !candidate_set.contains(*entry))
|
||||
.cloned()
|
||||
.collect();
|
||||
if missing.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.denied_domains",
|
||||
"missing managed denied_domains entries",
|
||||
format!("{missing:?}"),
|
||||
))
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
|
||||
let allowed_set: HashSet<String> = allow_unix_sockets
|
||||
.iter()
|
||||
.map(|s| s.to_ascii_lowercase())
|
||||
.collect();
|
||||
let _ = Constrained::new(
|
||||
config.network_proxy.policy.allow_unix_sockets.clone(),
|
||||
move |candidate| {
|
||||
let mut invalid = Vec::new();
|
||||
for entry in candidate {
|
||||
if !allowed_set.contains(&entry.to_ascii_lowercase()) {
|
||||
invalid.push(entry.clone());
|
||||
}
|
||||
}
|
||||
if invalid.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(invalid_value(
|
||||
"network_proxy.policy.allow_unix_sockets",
|
||||
format!("{invalid:?}"),
|
||||
"subset of managed allow_unix_sockets",
|
||||
))
|
||||
}
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn network_mode_rank(mode: NetworkMode) -> u8 {
|
||||
match mode {
|
||||
NetworkMode::Limited => 0,
|
||||
NetworkMode::Full => 1,
|
||||
}
|
||||
}
|
||||
188
codex-rs/network-proxy/src/upstream.rs
Normal file
188
codex-rs/network-proxy/src/upstream.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
use rama_core::Layer;
|
||||
use rama_core::Service;
|
||||
use rama_core::error::BoxError;
|
||||
use rama_core::error::ErrorContext as _;
|
||||
use rama_core::error::OpaqueError;
|
||||
use rama_core::extensions::ExtensionsMut;
|
||||
use rama_core::extensions::ExtensionsRef;
|
||||
use rama_core::service::BoxService;
|
||||
use rama_http::Body;
|
||||
use rama_http::Request;
|
||||
use rama_http::Response;
|
||||
use rama_http::layer::version_adapter::RequestVersionAdapter;
|
||||
use rama_http_backend::client::HttpClientService;
|
||||
use rama_http_backend::client::HttpConnector;
|
||||
use rama_http_backend::client::proxy::layer::HttpProxyConnectorLayer;
|
||||
use rama_net::address::ProxyAddress;
|
||||
use rama_net::client::EstablishedClientConnection;
|
||||
use rama_net::http::RequestContext;
|
||||
use rama_tcp::client::service::TcpConnector;
|
||||
use rama_tls_boring::client::TlsConnectorDataBuilder;
|
||||
use rama_tls_boring::client::TlsConnectorLayer;
|
||||
use tracing::warn;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
use rama_unix::client::UnixConnector;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct ProxyConfig {
|
||||
http: Option<ProxyAddress>,
|
||||
https: Option<ProxyAddress>,
|
||||
all: Option<ProxyAddress>,
|
||||
}
|
||||
|
||||
impl ProxyConfig {
|
||||
fn from_env() -> Self {
|
||||
let http = read_proxy_env(&["HTTP_PROXY", "http_proxy"]);
|
||||
let https = read_proxy_env(&["HTTPS_PROXY", "https_proxy"]);
|
||||
let all = read_proxy_env(&["ALL_PROXY", "all_proxy"]);
|
||||
Self { http, https, all }
|
||||
}
|
||||
|
||||
fn proxy_for_request(&self, req: &Request) -> Option<ProxyAddress> {
|
||||
let is_secure = RequestContext::try_from(req)
|
||||
.map(|ctx| ctx.protocol.is_secure())
|
||||
.unwrap_or(false);
|
||||
self.proxy_for_protocol(is_secure)
|
||||
}
|
||||
|
||||
fn proxy_for_protocol(&self, is_secure: bool) -> Option<ProxyAddress> {
|
||||
if is_secure {
|
||||
self.https
|
||||
.clone()
|
||||
.or_else(|| self.http.clone())
|
||||
.or_else(|| self.all.clone())
|
||||
} else {
|
||||
self.http.clone().or_else(|| self.all.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_proxy_env(keys: &[&str]) -> Option<ProxyAddress> {
|
||||
for key in keys {
|
||||
let Ok(value) = std::env::var(key) else {
|
||||
continue;
|
||||
};
|
||||
let value = value.trim();
|
||||
if value.is_empty() {
|
||||
continue;
|
||||
}
|
||||
match ProxyAddress::try_from(value) {
|
||||
Ok(proxy) => {
|
||||
if proxy
|
||||
.protocol
|
||||
.as_ref()
|
||||
.map(rama_net::Protocol::is_http)
|
||||
.unwrap_or(true)
|
||||
{
|
||||
return Some(proxy);
|
||||
}
|
||||
warn!("ignoring {key}: non-http proxy protocol");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("ignoring {key}: invalid proxy address ({err})");
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn proxy_for_connect() -> Option<ProxyAddress> {
|
||||
ProxyConfig::from_env().proxy_for_protocol(true)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct UpstreamClient {
|
||||
connector: BoxService<
|
||||
Request<Body>,
|
||||
EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
|
||||
BoxError,
|
||||
>,
|
||||
proxy_config: ProxyConfig,
|
||||
}
|
||||
|
||||
impl UpstreamClient {
|
||||
pub(crate) fn direct() -> Self {
|
||||
Self::new(ProxyConfig::default())
|
||||
}
|
||||
|
||||
pub(crate) fn from_env_proxy() -> Self {
|
||||
Self::new(ProxyConfig::from_env())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub(crate) fn unix_socket(path: &str) -> Self {
|
||||
let connector = build_unix_connector(path);
|
||||
Self {
|
||||
connector,
|
||||
proxy_config: ProxyConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn new(proxy_config: ProxyConfig) -> Self {
|
||||
let connector = build_http_connector();
|
||||
Self {
|
||||
connector,
|
||||
proxy_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Service<Request<Body>> for UpstreamClient {
|
||||
type Output = Response;
|
||||
type Error = OpaqueError;
|
||||
|
||||
async fn serve(&self, mut req: Request<Body>) -> Result<Self::Output, Self::Error> {
|
||||
if let Some(proxy) = self.proxy_config.proxy_for_request(&req) {
|
||||
req.extensions_mut().insert(proxy);
|
||||
}
|
||||
|
||||
let uri = req.uri().clone();
|
||||
let EstablishedClientConnection {
|
||||
input: mut req,
|
||||
conn: http_connection,
|
||||
} = self
|
||||
.connector
|
||||
.serve(req)
|
||||
.await
|
||||
.map_err(OpaqueError::from_boxed)?;
|
||||
|
||||
req.extensions_mut()
|
||||
.extend(http_connection.extensions().clone());
|
||||
|
||||
http_connection
|
||||
.serve(req)
|
||||
.await
|
||||
.map_err(OpaqueError::from_boxed)
|
||||
.with_context(|| format!("http request failure for uri: {uri}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_http_connector() -> BoxService<
|
||||
Request<Body>,
|
||||
EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
|
||||
BoxError,
|
||||
> {
|
||||
let transport = TcpConnector::default();
|
||||
let proxy = HttpProxyConnectorLayer::optional().into_layer(transport);
|
||||
let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
|
||||
let tls = TlsConnectorLayer::auto()
|
||||
.with_connector_data(tls_config)
|
||||
.into_layer(proxy);
|
||||
let tls = RequestVersionAdapter::new(tls);
|
||||
let connector = HttpConnector::new(tls);
|
||||
connector.boxed()
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
fn build_unix_connector(
|
||||
path: &str,
|
||||
) -> BoxService<
|
||||
Request<Body>,
|
||||
EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
|
||||
BoxError,
|
||||
> {
|
||||
let transport = UnixConnector::fixed(path);
|
||||
let connector = HttpConnector::new(transport);
|
||||
connector.boxed()
|
||||
}
|
||||
Reference in New Issue
Block a user