warpdrive_proxy/middleware/
trusted_ranges.rs

1//! Trusted IP ranges middleware for bypassing rate limiting and concurrency limits
2//!
3//! This middleware provides two features:
4//! 1. **Client IP Normalization** - Extract real client IP from proxy headers
5//! 2. **Trusted IP Ranges** - Bypass protections for requests from trusted proxies/CDNs
6//!
7//! # Architecture
8//!
9//! When enabled, this middleware:
10//! - Loads CIDR ranges from a file on startup
11//! - Checks if request source IP matches any trusted range
12//! - Extracts real client IP from configured header (only when the source IP is trusted)
13//! - Sets `ctx.trusted_source = true` for trusted requests
14//! - Sets `ctx.real_client_ip` with normalized IP
15//!
16//! # Configuration
17//!
18//! ```bash
19//! WARPDRIVE_CLIENT_IP_HEADER=CF-Connecting-IP  # Header containing real client IP
20//! WARPDRIVE_TRUSTED_RANGES_FILE=/etc/warpdrive/trusted-ips.txt  # CIDR ranges file
21//! ```
22//!
23//! # Example trusted-ips.txt
24//!
25//! ```text
26//! # Cloudflare IPv4 ranges
27//! 173.245.48.0/20
28//! 103.21.244.0/22
29//!
30//! # Cloudflare IPv6 ranges
31//! 2606:4700::/32
32//!
33//! # Internal load balancers
34//! 10.0.0.0/8
35//! 192.168.0.0/16
36//! ```
37
38use async_trait::async_trait;
39use ipnet::IpNet;
40use pingora::prelude::*;
41use std::fs;
42use std::net::IpAddr;
43use std::path::{Path, PathBuf};
44use std::sync::Arc;
45use tracing::{debug, error, info, warn};
46
47use super::{Middleware, MiddlewareContext};
48
49/// Middleware for trusted IP ranges and client IP normalization
50pub struct TrustedRangesMiddleware {
51    /// Loaded CIDR ranges for trusted sources
52    ranges: Arc<Vec<IpNet>>,
53    /// Optional header name to extract real client IP from
54    client_ip_header: Option<String>,
55    /// Path to ranges file (for diagnostics)
56    #[allow(dead_code)]
57    ranges_file: Option<PathBuf>,
58}
59
60impl TrustedRangesMiddleware {
61    /// Create a new middleware with ranges from a file
62    ///
63    /// # Arguments
64    /// * `ranges_file` - Path to CIDR ranges file
65    /// * `client_ip_header` - Optional header name for real client IP
66    pub fn new(ranges_file: Option<PathBuf>, client_ip_header: Option<String>) -> Result<Self> {
67        let ranges = if let Some(ref path) = ranges_file {
68            match load_ranges_from_file(path) {
69                Ok(r) => {
70                    info!("Loaded {} trusted IP ranges from {:?}", r.len(), path);
71                    Arc::new(r)
72                }
73                Err(e) => {
74                    error!("Failed to load trusted ranges from {:?}: {}", path, e);
75                    Arc::new(vec![])
76                }
77            }
78        } else {
79            info!("No trusted ranges file configured");
80            Arc::new(vec![])
81        };
82
83        Ok(Self {
84            ranges,
85            client_ip_header,
86            ranges_file,
87        })
88    }
89
90    /// Check if an IP address is in the trusted ranges
91    fn is_trusted(&self, ip: &IpAddr) -> bool {
92        self.ranges.iter().any(|net| net.contains(ip))
93    }
94
95    /// Extract real client IP from session
96    ///
97    /// If `client_ip_header` is configured and the source IP is trusted,
98    /// extract IP from the header. Otherwise, fall back to socket IP.
99    ///
100    /// # Security
101    /// Only trusts the header if source IP is in trusted ranges to prevent spoofing.
102    fn get_real_client_ip(&self, session: &Session, source_is_trusted: bool) -> IpAddr {
103        // Only trust custom headers if source is from trusted range
104        if source_is_trusted {
105            if let Some(ref header_name) = self.client_ip_header {
106                if let Some(header_value) = session.req_header().headers.get(header_name) {
107                    if let Ok(value_str) = header_value.to_str() {
108                        // Try parsing as single IP
109                        if let Ok(ip) = value_str.parse::<IpAddr>() {
110                            debug!("Using real client IP from {}: {}", header_name, ip);
111                            return ip;
112                        }
113
114                        // Handle X-Forwarded-For format (comma-separated list)
115                        if header_name.eq_ignore_ascii_case("x-forwarded-for") {
116                            if let Some(first_ip) = value_str.split(',').next() {
117                                if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
118                                    debug!("Using first IP from X-Forwarded-For chain: {}", ip);
119                                    return ip;
120                                }
121                            }
122                        }
123
124                        warn!(
125                            "Failed to parse IP from {} header: {}",
126                            header_name, value_str
127                        );
128                    }
129                }
130            }
131        }
132
133        // Fallback to socket IP
134        session
135            .client_addr()
136            .and_then(|addr| addr.as_inet().map(|inet| inet.ip()))
137            .unwrap_or_else(|| "0.0.0.0".parse().unwrap())
138    }
139}
140
141#[async_trait]
142impl Middleware for TrustedRangesMiddleware {
143    async fn request_filter(
144        &self,
145        session: &mut Session,
146        ctx: &mut MiddlewareContext,
147    ) -> Result<()> {
148        // Get source IP from socket
149        let source_ip = session
150            .client_addr()
151            .and_then(|addr| addr.as_inet().map(|inet| inet.ip()))
152            .unwrap_or_else(|| "0.0.0.0".parse().unwrap());
153
154        // Check if source is trusted
155        let is_trusted = self.is_trusted(&source_ip);
156
157        if is_trusted {
158            debug!("Request from trusted source: {}", source_ip);
159            ctx.trusted_source = true;
160        }
161
162        // Extract real client IP (honors trusted status for security)
163        let real_ip = self.get_real_client_ip(session, is_trusted);
164        ctx.real_client_ip = real_ip;
165
166        if is_trusted && real_ip != source_ip {
167            debug!(
168                "Normalized client IP: {} -> {} (via {})",
169                source_ip,
170                real_ip,
171                self.client_ip_header.as_deref().unwrap_or("socket")
172            );
173        }
174
175        Ok(())
176    }
177}
178
179/// Load CIDR ranges from a file
180///
181/// File format:
182/// - One CIDR range per line
183/// - Lines starting with `#` are comments
184/// - Empty lines are ignored
185/// - Invalid lines are logged and skipped
186fn load_ranges_from_file(path: &Path) -> Result<Vec<IpNet>> {
187    let content = fs::read_to_string(path).map_err(|e| {
188        Error::explain(
189            ErrorType::ReadError,
190            format!("Failed to read trusted ranges file: {}", e),
191        )
192    })?;
193
194    let mut ranges = Vec::new();
195    for (line_num, line) in content.lines().enumerate() {
196        let trimmed = line.trim();
197
198        // Skip empty lines and comments
199        if trimmed.is_empty() || trimmed.starts_with('#') {
200            continue;
201        }
202
203        // Parse CIDR range
204        match trimmed.parse::<IpNet>() {
205            Ok(net) => ranges.push(net),
206            Err(e) => {
207                warn!(
208                    "Invalid CIDR range at line {}: {} ({})",
209                    line_num + 1,
210                    trimmed,
211                    e
212                );
213            }
214        }
215    }
216
217    Ok(ranges)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Temporary ranges file that is deleted on drop.
    ///
    /// Because cleanup happens in `Drop`, the file is removed even when an
    /// assertion panics. The process id is part of the name, so concurrent test
    /// binaries sharing the temp dir do not overwrite each other's files.
    struct TempRanges(PathBuf);

    impl TempRanges {
        fn new(name: &str, content: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{}.txt", name, std::process::id()));
            fs::write(&path, content).unwrap();
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempRanges {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask a test failure.
            let _ = fs::remove_file(&self.0);
        }
    }

    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn test_parse_ipv4_ranges() {
        let file = TempRanges::new(
            "test_ranges_ipv4",
            "# Cloudflare\n173.245.48.0/20\n103.21.244.0/22\n\n# Empty line above\n",
        );

        let ranges = load_ranges_from_file(file.path()).unwrap();
        assert_eq!(ranges.len(), 2);
        assert!(ranges[0].contains(&ip("173.245.48.100")));
        assert!(ranges[1].contains(&ip("103.21.244.50")));
    }

    #[test]
    fn test_parse_ipv6_ranges() {
        let file = TempRanges::new("test_ranges_ipv6", "2606:4700::/32\n2405:8100::/32");

        let ranges = load_ranges_from_file(file.path()).unwrap();
        assert_eq!(ranges.len(), 2);
        assert!(ranges[0].contains(&ip("2606:4700::1")));
        assert!(ranges[1].contains(&ip("2405:8100::1")));
    }

    #[test]
    fn test_invalid_ranges_skipped() {
        let file = TempRanges::new(
            "test_ranges_invalid",
            "173.245.48.0/20\ninvalid-cidr\n192.168.0.0/16\n",
        );

        let ranges = load_ranges_from_file(file.path()).unwrap();
        assert_eq!(ranges.len(), 2); // Invalid line skipped
    }

    #[test]
    fn test_is_trusted() {
        let file = TempRanges::new("test_is_trusted", "173.245.48.0/20\n10.0.0.0/8");

        let middleware =
            TrustedRangesMiddleware::new(Some(file.path().to_path_buf()), None).unwrap();

        assert!(middleware.is_trusted(&ip("173.245.48.100")));
        assert!(middleware.is_trusted(&ip("10.0.1.50")));
        assert!(!middleware.is_trusted(&ip("1.2.3.4")));
    }

    #[test]
    fn test_missing_file_trusts_nothing() {
        let path = std::env::temp_dir().join(format!(
            "test_ranges_missing_{}.txt",
            std::process::id()
        ));
        let _ = fs::remove_file(&path);

        assert!(load_ranges_from_file(&path).is_err());

        // `new` logs the failure and falls back to an empty range list.
        let middleware = TrustedRangesMiddleware::new(Some(path), None).unwrap();
        assert!(!middleware.is_trusted(&ip("10.0.0.1")));
    }

    #[test]
    fn test_no_ranges_file_trusts_nothing() {
        let middleware = TrustedRangesMiddleware::new(None, None).unwrap();
        assert!(!middleware.is_trusted(&ip("127.0.0.1")));
    }
}