tork_core/middleware/
proxy_headers.rs1use std::net::IpAddr;
4
5use http::header::HOST;
6use ipnet::IpNet;
7use tracing::warn;
8
9use crate::error::Result;
10use crate::extract::{peer_addr_from_extensions, RequestScheme};
11use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
12use crate::response::Response;
13use crate::router::BoxFuture;
14
/// Header whose value, when sent by a trusted peer, replaces the request's `Host` header.
const FORWARDED_HOST: &str = "x-forwarded-host";
/// Header whose value (`http` or `https`, any ASCII case), when sent by a trusted peer,
/// is recorded as the request's `RequestScheme` extension.
const FORWARDED_PROTO: &str = "x-forwarded-proto";

20pub struct ProxyHeaders {
29 trusted_ips: Vec<IpAddr>,
30 trusted_cidrs: Vec<IpNet>,
31}
32
33impl ProxyHeaders {
34 pub fn new() -> Self {
36 Self {
37 trusted_ips: Vec::new(),
38 trusted_cidrs: Vec::new(),
39 }
40 }
41
42 pub fn trust_proxy(mut self, addr: IpAddr) -> Self {
44 self.trusted_ips.push(addr);
45 self
46 }
47
48 pub fn trust_cidr(mut self, network: IpNet) -> Self {
50 self.trusted_cidrs.push(network);
51 self
52 }
53
54 pub fn trust_loopback(self) -> Self {
56 self.trust_proxy(IpAddr::from([127, 0, 0, 1]))
57 .trust_proxy(IpAddr::from(std::net::Ipv6Addr::LOCALHOST))
58 }
59
60 fn is_trusted(&self, request: &Request) -> bool {
61 let Some(peer) = peer_addr_from_extensions(request.extensions()) else {
62 return false;
63 };
64 self.trusted_ips.iter().any(|addr| *addr == peer.ip())
65 || self
66 .trusted_cidrs
67 .iter()
68 .any(|network| network.contains(&peer.ip()))
69 }
70
71 fn forwarded_value<'a>(request: &'a Request, name: &'static str) -> Option<&'a str> {
72 request
73 .headers()
74 .get(name)
75 .and_then(|value| value.to_str().ok())
76 .and_then(|value| value.split(',').next())
77 .map(str::trim)
78 .filter(|value| !value.is_empty())
79 }
80}
81
82impl Default for ProxyHeaders {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl Middleware for ProxyHeaders {
89 fn handle(&self, mut request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
90 if !self.is_trusted(&request) {
91 return next.run(request);
92 }
93
94 if let Some(forwarded_host) = Self::forwarded_value(&request, FORWARDED_HOST) {
95 if let Ok(value) = http::HeaderValue::from_str(forwarded_host) {
96 request.headers_mut().insert(HOST, value);
97 }
98 }
99
100 if let Some(forwarded_proto) = Self::forwarded_value(&request, FORWARDED_PROTO) {
101 let scheme = if forwarded_proto.eq_ignore_ascii_case("https") {
102 Some(RequestScheme::Https)
103 } else if forwarded_proto.eq_ignore_ascii_case("http") {
104 Some(RequestScheme::Http)
105 } else {
106 None
107 };
108
109 if let Some(scheme) = scheme {
110 request.extensions_mut().insert(scheme);
111 } else {
112 warn!("tork: ignoring unsupported X-Forwarded-Proto value");
113 }
114 }
115 next.run(request)
116 }
117
118 fn name(&self) -> &'static str {
119 "ProxyHeaders"
120 }
121
122 fn duplicate_policy(&self) -> DuplicatePolicy {
123 DuplicatePolicy::Reject
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builtin_metadata_is_stable() {
        let middleware = ProxyHeaders::new();
        assert_eq!(middleware.name(), "ProxyHeaders");
        assert_eq!(middleware.duplicate_policy(), DuplicatePolicy::Reject);
    }

    #[test]
    fn default_impl_uses_new() {
        let middleware: ProxyHeaders = Default::default();
        // A default-constructed middleware must trust nobody.
        assert!(middleware.trusted_ips.is_empty());
        assert!(middleware.trusted_cidrs.is_empty());
    }

    #[test]
    fn trust_builders_register_expected_networks() {
        let middleware = ProxyHeaders::new()
            .trust_proxy(IpAddr::from([10, 0, 0, 1]))
            .trust_cidr("10.0.0.0/24".parse().unwrap())
            .trust_loopback();

        // Pin the exact lists so a missing or extra registration is caught,
        // including the IPv6 loopback added by `trust_loopback`.
        assert_eq!(
            middleware.trusted_ips,
            vec![
                IpAddr::from([10, 0, 0, 1]),
                IpAddr::from([127, 0, 0, 1]),
                IpAddr::from(std::net::Ipv6Addr::LOCALHOST),
            ]
        );
        assert_eq!(
            middleware.trusted_cidrs,
            vec!["10.0.0.0/24".parse::<IpNet>().unwrap()]
        );
    }

    #[test]
    fn trust_loopback_registers_both_address_families() {
        let middleware = ProxyHeaders::new().trust_loopback();

        assert!(middleware.trusted_cidrs.is_empty());
        assert!(middleware
            .trusted_ips
            .contains(&IpAddr::from([127, 0, 0, 1])));
        assert!(middleware
            .trusted_ips
            .contains(&IpAddr::from(std::net::Ipv6Addr::LOCALHOST)));
    }
}