warpdrive_proxy/middleware/
trusted_ranges.rs1use async_trait::async_trait;
39use ipnet::IpNet;
40use pingora::prelude::*;
41use std::fs;
42use std::net::IpAddr;
43use std::path::{Path, PathBuf};
44use std::sync::Arc;
45use tracing::{debug, error, info, warn};
46
47use super::{Middleware, MiddlewareContext};
48
49pub struct TrustedRangesMiddleware {
51 ranges: Arc<Vec<IpNet>>,
53 client_ip_header: Option<String>,
55 #[allow(dead_code)]
57 ranges_file: Option<PathBuf>,
58}
59
60impl TrustedRangesMiddleware {
61 pub fn new(ranges_file: Option<PathBuf>, client_ip_header: Option<String>) -> Result<Self> {
67 let ranges = if let Some(ref path) = ranges_file {
68 match load_ranges_from_file(path) {
69 Ok(r) => {
70 info!("Loaded {} trusted IP ranges from {:?}", r.len(), path);
71 Arc::new(r)
72 }
73 Err(e) => {
74 error!("Failed to load trusted ranges from {:?}: {}", path, e);
75 Arc::new(vec![])
76 }
77 }
78 } else {
79 info!("No trusted ranges file configured");
80 Arc::new(vec![])
81 };
82
83 Ok(Self {
84 ranges,
85 client_ip_header,
86 ranges_file,
87 })
88 }
89
90 fn is_trusted(&self, ip: &IpAddr) -> bool {
92 self.ranges.iter().any(|net| net.contains(ip))
93 }
94
95 fn get_real_client_ip(&self, session: &Session, source_is_trusted: bool) -> IpAddr {
103 if source_is_trusted {
105 if let Some(ref header_name) = self.client_ip_header {
106 if let Some(header_value) = session.req_header().headers.get(header_name) {
107 if let Ok(value_str) = header_value.to_str() {
108 if let Ok(ip) = value_str.parse::<IpAddr>() {
110 debug!("Using real client IP from {}: {}", header_name, ip);
111 return ip;
112 }
113
114 if header_name.eq_ignore_ascii_case("x-forwarded-for") {
116 if let Some(first_ip) = value_str.split(',').next() {
117 if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
118 debug!("Using first IP from X-Forwarded-For chain: {}", ip);
119 return ip;
120 }
121 }
122 }
123
124 warn!(
125 "Failed to parse IP from {} header: {}",
126 header_name, value_str
127 );
128 }
129 }
130 }
131 }
132
133 session
135 .client_addr()
136 .and_then(|addr| addr.as_inet().map(|inet| inet.ip()))
137 .unwrap_or_else(|| "0.0.0.0".parse().unwrap())
138 }
139}
140
141#[async_trait]
142impl Middleware for TrustedRangesMiddleware {
143 async fn request_filter(
144 &self,
145 session: &mut Session,
146 ctx: &mut MiddlewareContext,
147 ) -> Result<()> {
148 let source_ip = session
150 .client_addr()
151 .and_then(|addr| addr.as_inet().map(|inet| inet.ip()))
152 .unwrap_or_else(|| "0.0.0.0".parse().unwrap());
153
154 let is_trusted = self.is_trusted(&source_ip);
156
157 if is_trusted {
158 debug!("Request from trusted source: {}", source_ip);
159 ctx.trusted_source = true;
160 }
161
162 let real_ip = self.get_real_client_ip(session, is_trusted);
164 ctx.real_client_ip = real_ip;
165
166 if is_trusted && real_ip != source_ip {
167 debug!(
168 "Normalized client IP: {} -> {} (via {})",
169 source_ip,
170 real_ip,
171 self.client_ip_header.as_deref().unwrap_or("socket")
172 );
173 }
174
175 Ok(())
176 }
177}
178
179fn load_ranges_from_file(path: &Path) -> Result<Vec<IpNet>> {
187 let content = fs::read_to_string(path).map_err(|e| {
188 Error::explain(
189 ErrorType::ReadError,
190 format!("Failed to read trusted ranges file: {}", e),
191 )
192 })?;
193
194 let mut ranges = Vec::new();
195 for (line_num, line) in content.lines().enumerate() {
196 let trimmed = line.trim();
197
198 if trimmed.is_empty() || trimmed.starts_with('#') {
200 continue;
201 }
202
203 match trimmed.parse::<IpNet>() {
205 Ok(net) => ranges.push(net),
206 Err(e) => {
207 warn!(
208 "Invalid CIDR range at line {}: {} ({})",
209 line_num + 1,
210 trimmed,
211 e
212 );
213 }
214 }
215 }
216
217 Ok(ranges)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Temporary ranges file that is deleted when the guard is dropped, so a
    /// failing assertion does not leave the file behind.
    struct TempRangesFile {
        path: PathBuf,
    }

    impl TempRangesFile {
        /// Writes `content` to a file in the system temp directory. The
        /// process id is part of the name so concurrent test runs do not
        /// overwrite each other's fixtures.
        fn create(stem: &str, content: &str) -> Self {
            let file_name = format!("{}_{}.txt", stem, std::process::id());
            let path = std::env::temp_dir().join(file_name);
            fs::write(&path, content).expect("temporary ranges file should be writable");
            Self { path }
        }
    }

    impl Drop for TempRangesFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file must not fail a test.
            let _ = fs::remove_file(&self.path);
        }
    }

    /// Parses an address literal used by the assertions below.
    fn ip(text: &str) -> IpAddr {
        text.parse().expect("test address literal should be valid")
    }

    #[test]
    fn test_parse_ipv4_ranges() {
        let file = TempRangesFile::create(
            "test_ranges_ipv4",
            "# Cloudflare\n173.245.48.0/20\n103.21.244.0/22\n\n# Empty line above\n",
        );

        let ranges = load_ranges_from_file(&file.path).unwrap();
        assert_eq!(ranges.len(), 2);
        assert!(ranges[0].contains(&ip("173.245.48.100")));
        assert!(ranges[1].contains(&ip("103.21.244.50")));
    }

    #[test]
    fn test_parse_ipv6_ranges() {
        let file = TempRangesFile::create("test_ranges_ipv6", "2606:4700::/32\n2405:8100::/32");

        let ranges = load_ranges_from_file(&file.path).unwrap();
        assert_eq!(ranges.len(), 2);
        assert!(ranges[0].contains(&ip("2606:4700::1")));
        assert!(ranges[1].contains(&ip("2405:8100::1")));
    }

    #[test]
    fn test_invalid_ranges_skipped() {
        let file = TempRangesFile::create(
            "test_ranges_invalid",
            "173.245.48.0/20\ninvalid-cidr\n192.168.0.0/16\n",
        );

        // The unparsable middle line is dropped; the two valid ones remain.
        let ranges = load_ranges_from_file(&file.path).unwrap();
        assert_eq!(ranges.len(), 2);
    }

    #[test]
    fn test_is_trusted() {
        let file = TempRangesFile::create("test_is_trusted", "173.245.48.0/20\n10.0.0.0/8");

        let middleware = TrustedRangesMiddleware::new(Some(file.path.clone()), None).unwrap();

        assert!(middleware.is_trusted(&ip("173.245.48.100")));
        assert!(middleware.is_trusted(&ip("10.0.1.50")));
        assert!(!middleware.is_trusted(&ip("1.2.3.4")));
    }
}