refactor(network-proxy): use rama host parsing

This commit is contained in:
viyatb-oai
2026-03-02 12:36:21 -08:00
parent ff129aba82
commit d5aca19239

View File

@@ -43,6 +43,8 @@ use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http::header;
use rama_http::headers::HeaderMapExt;
use rama_http::headers::Host;
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
use rama_http::matcher::MethodMatcher;
use rama_http_backend::client::proxy::layer::HttpProxyConnector;
@@ -464,16 +466,17 @@ async fn http_plain_proxy(
};
}
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
Ok(authority) => authority,
let request_ctx = match RequestContext::try_from(&req) {
Ok(request_ctx) => request_ctx,
Err(err) => {
warn!("missing host: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
}
};
let authority = request_ctx.host_with_port();
let host = normalize_host(&authority.host.to_string());
let port = authority.port;
if let Err(reason) = validate_absolute_form_host_header(&req, &host, port) {
if let Err(reason) = validate_absolute_form_host_header(&req, &request_ctx) {
let client = client.as_deref().unwrap_or_default();
let host_header = req
.headers()
@@ -663,83 +666,40 @@ fn request_network_attempt_id(req: &Request) -> Option<String> {
.or_else(|| attempt_id_from_proxy_authorization(req.headers().get("authorization")))
}
#[derive(Debug, PartialEq, Eq)]
struct HostHeaderAuthority {
host: String,
port: Option<u16>,
}
fn validate_absolute_form_host_header(
req: &Request,
authority_host: &str,
authority_port: u16,
request_ctx: &RequestContext,
) -> Result<(), &'static str> {
if req.uri().scheme_str().is_none() {
return Ok(());
}
let Some(host_header) = req.headers().get(header::HOST) else {
let Some(host_header) = req
.headers()
.typed_try_get::<Host>()
.map_err(|_| "invalid Host header")?
else {
return Ok(());
};
let parsed = parse_host_header_authority(host_header).ok_or("invalid Host header")?;
if parsed.host != authority_host {
if host_header.0.host != request_ctx.authority.host {
return Err("Host header does not match request target");
}
if let Some(host_port) = parsed.port {
if host_port != authority_port {
if let Some(host_port) = host_header.0.port {
if Some(host_port) != request_ctx.authority.port {
return Err("Host header does not match request target");
}
return Ok(());
}
let target_port_was_explicit = req.uri().port_u16().is_some();
let default_port = req.uri().scheme_str().and_then(default_port_for_scheme);
if target_port_was_explicit && default_port != Some(authority_port) {
if !request_ctx.authority_has_default_port() {
return Err("Host header does not match request target");
}
Ok(())
}
fn parse_host_header_authority(value: &HeaderValue) -> Option<HostHeaderAuthority> {
let raw = value.to_str().ok()?.trim();
if raw.is_empty() {
return None;
}
let host = normalize_host(raw);
if host.is_empty() {
return None;
}
let port = if raw.starts_with('[') {
let end = raw.find(']')?;
let remainder = &raw[end + 1..];
if remainder.is_empty() {
None
} else {
Some(remainder.strip_prefix(':')?.parse::<u16>().ok()?)
}
} else if raw.bytes().filter(|byte| *byte == b':').count() == 1 {
let (_, port) = raw.rsplit_once(':')?;
Some(port.parse::<u16>().ok()?)
} else {
None
};
Some(HostHeaderAuthority { host, port })
}
fn default_port_for_scheme(scheme: &str) -> Option<u16> {
match scheme {
"http" | "ws" => Some(80),
"https" | "wss" => Some(443),
_ => None,
}
}
fn remove_hop_by_hop_request_headers(headers: &mut HeaderMap) {
while let Some(raw_connection) = headers.get(header::CONNECTION).cloned() {
headers.remove(header::CONNECTION);
@@ -1022,7 +982,7 @@ mod tests {
.unwrap();
assert_eq!(
validate_absolute_form_host_header(&req, "example.com", 80),
validate_absolute_form_host_header(&req, &RequestContext::try_from(&req).unwrap(),),
Ok(())
);
}
@@ -1037,7 +997,7 @@ mod tests {
.unwrap();
assert_eq!(
validate_absolute_form_host_header(&req, "raw.githubusercontent.com", 80),
validate_absolute_form_host_header(&req, &RequestContext::try_from(&req).unwrap(),),
Err("Host header does not match request target")
);
}
@@ -1052,7 +1012,7 @@ mod tests {
.unwrap();
assert_eq!(
validate_absolute_form_host_header(&req, "example.com", 8080),
validate_absolute_form_host_header(&req, &RequestContext::try_from(&req).unwrap(),),
Err("Host header does not match request target")
);
}