Skip to main content

tork_core/middleware/
proxy_headers.rs

1//! Proxy-header normalization middleware.
2
3use std::net::IpAddr;
4
5use http::header::HOST;
6use ipnet::IpNet;
7use tracing::warn;
8
9use crate::error::Result;
10use crate::extract::{peer_addr_from_extensions, RequestScheme};
11use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
12use crate::response::Response;
13use crate::router::BoxFuture;
14
/// Lower-case name of the header in which a terminating proxy passes along the
/// host the client originally requested.
const FORWARDED_HOST: &str = "x-forwarded-host";

/// Lower-case name of the header in which a terminating proxy passes along the
/// scheme (`http` / `https`) the client originally used.
const FORWARDED_PROTO: &str = "x-forwarded-proto";
19
20/// Normalizes proxy-forwarded headers onto the request.
21///
22/// When the request comes through a terminating proxy, the original host arrives
23/// in `X-Forwarded-Host`; this middleware rewrites the `Host` header from it so
24/// downstream host-based middleware (such as [`TrustedHost`](super::TrustedHost))
25/// sees the client-facing host. Register it before those middlewares. The
26/// forwarded scheme is honored directly by
27/// [`HttpsRedirect`](super::HttpsRedirect).
28pub struct ProxyHeaders {
29    trusted_ips: Vec<IpAddr>,
30    trusted_cidrs: Vec<IpNet>,
31}
32
33impl ProxyHeaders {
34    /// Creates the middleware.
35    pub fn new() -> Self {
36        Self {
37            trusted_ips: Vec::new(),
38            trusted_cidrs: Vec::new(),
39        }
40    }
41
42    /// Trusts a single reverse proxy address.
43    pub fn trust_proxy(mut self, addr: IpAddr) -> Self {
44        self.trusted_ips.push(addr);
45        self
46    }
47
48    /// Trusts a reverse-proxy network.
49    pub fn trust_cidr(mut self, network: IpNet) -> Self {
50        self.trusted_cidrs.push(network);
51        self
52    }
53
54    /// Trusts loopback reverse proxies (`127.0.0.1` and `::1`).
55    pub fn trust_loopback(self) -> Self {
56        self.trust_proxy(IpAddr::from([127, 0, 0, 1]))
57            .trust_proxy(IpAddr::from(std::net::Ipv6Addr::LOCALHOST))
58    }
59
60    fn is_trusted(&self, request: &Request) -> bool {
61        let Some(peer) = peer_addr_from_extensions(request.extensions()) else {
62            return false;
63        };
64        self.trusted_ips.iter().any(|addr| *addr == peer.ip())
65            || self
66                .trusted_cidrs
67                .iter()
68                .any(|network| network.contains(&peer.ip()))
69    }
70
71    fn forwarded_value<'a>(request: &'a Request, name: &'static str) -> Option<&'a str> {
72        request
73            .headers()
74            .get(name)
75            .and_then(|value| value.to_str().ok())
76            .and_then(|value| value.split(',').next())
77            .map(str::trim)
78            .filter(|value| !value.is_empty())
79    }
80}
81
82impl Default for ProxyHeaders {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl Middleware for ProxyHeaders {
89    fn handle(&self, mut request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
90        if !self.is_trusted(&request) {
91            return next.run(request);
92        }
93
94        if let Some(forwarded_host) = Self::forwarded_value(&request, FORWARDED_HOST) {
95            if let Ok(value) = http::HeaderValue::from_str(forwarded_host) {
96                request.headers_mut().insert(HOST, value);
97            }
98        }
99
100        if let Some(forwarded_proto) = Self::forwarded_value(&request, FORWARDED_PROTO) {
101            let scheme = if forwarded_proto.eq_ignore_ascii_case("https") {
102                Some(RequestScheme::Https)
103            } else if forwarded_proto.eq_ignore_ascii_case("http") {
104                Some(RequestScheme::Http)
105            } else {
106                None
107            };
108
109            if let Some(scheme) = scheme {
110                request.extensions_mut().insert(scheme);
111            } else {
112                warn!("tork: ignoring unsupported X-Forwarded-Proto value");
113            }
114        }
115        next.run(request)
116    }
117
118    fn name(&self) -> &'static str {
119        "ProxyHeaders"
120    }
121
122    fn duplicate_policy(&self) -> DuplicatePolicy {
123        DuplicatePolicy::Reject
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builtin_metadata_is_stable() {
        let proxy_headers = ProxyHeaders::new();
        assert_eq!(proxy_headers.name(), "ProxyHeaders");
        assert_eq!(proxy_headers.duplicate_policy(), DuplicatePolicy::Reject);
    }

    #[test]
    fn default_impl_uses_new() {
        // `Default` must yield the same empty trust lists as `new()`.
        let from_default = ProxyHeaders::default();
        assert!(from_default.trusted_ips.is_empty());
        assert!(from_default.trusted_cidrs.is_empty());
    }

    #[test]
    fn trust_builders_register_expected_networks() {
        let proxy_ip = IpAddr::from([10, 0, 0, 1]);
        let loopback_v4 = IpAddr::from([127, 0, 0, 1]);
        let network: IpNet = "10.0.0.0/24".parse().unwrap();

        let proxy_headers = ProxyHeaders::new()
            .trust_proxy(proxy_ip)
            .trust_cidr(network)
            .trust_loopback();

        assert!(proxy_headers.trusted_ips.contains(&proxy_ip));
        assert!(proxy_headers.trusted_cidrs.contains(&network));
        assert!(proxy_headers.trusted_ips.contains(&loopback_v4));
    }
}