Skip to main content

rustyclaw_core/security/
ssrf.rs

//! SSRF (Server-Side Request Forgery) protection
//!
//! Validates URLs before making HTTP requests to prevent:
//! - Access to private IP ranges
//! - Access to localhost
//! - Access to cloud metadata endpoints
//! - DNS rebinding attacks (every address from two successive resolutions is validated)
//! - Unicode homograph attacks in domains
9
use ipnetwork::IpNetwork;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use tracing::warn;
14
15/// SSRF validator with configurable blocked CIDR ranges
16#[derive(Debug, Clone)]
17pub struct SsrfValidator {
18    /// List of blocked IP ranges (CIDR notation)
19    blocked_ranges: Vec<IpNetwork>,
20    /// Whether to allow private IPs (override for trusted environments)
21    #[allow(dead_code)]
22    allow_private_ips: bool,
23}
24
25impl Default for SsrfValidator {
26    fn default() -> Self {
27        Self::new(false)
28    }
29}
30
31impl SsrfValidator {
32    /// Create a new SSRF validator with default blocked ranges
33    pub fn new(allow_private_ips: bool) -> Self {
34        let blocked_ranges = if allow_private_ips {
35            vec![
36                // Only block cloud metadata endpoints if private IPs are allowed
37                IpNetwork::from_str("169.254.169.254/32").unwrap(), // AWS/GCP/Azure metadata
38            ]
39        } else {
40            vec![
41                // Private IP ranges (RFC 1918)
42                IpNetwork::from_str("10.0.0.0/8").unwrap(),
43                IpNetwork::from_str("172.16.0.0/12").unwrap(),
44                IpNetwork::from_str("192.168.0.0/16").unwrap(),
45                // Localhost
46                IpNetwork::from_str("127.0.0.0/8").unwrap(),
47                IpNetwork::from_str("::1/128").unwrap(),
48                // Link-local
49                IpNetwork::from_str("169.254.0.0/16").unwrap(),
50                IpNetwork::from_str("fe80::/10").unwrap(),
51                // Loopback
52                IpNetwork::from_str("::ffff:127.0.0.0/104").unwrap(), // IPv4-mapped IPv6 loopback
53                // Other reserved ranges
54                IpNetwork::from_str("0.0.0.0/8").unwrap(),
55                IpNetwork::from_str("100.64.0.0/10").unwrap(), // Carrier-grade NAT
56                IpNetwork::from_str("192.0.0.0/24").unwrap(),  // IETF protocol assignments
57                IpNetwork::from_str("192.0.2.0/24").unwrap(),  // TEST-NET-1
58                IpNetwork::from_str("198.18.0.0/15").unwrap(), // Benchmarking
59                IpNetwork::from_str("198.51.100.0/24").unwrap(), // TEST-NET-2
60                IpNetwork::from_str("203.0.113.0/24").unwrap(), // TEST-NET-3
61                IpNetwork::from_str("224.0.0.0/4").unwrap(),   // Multicast
62                IpNetwork::from_str("240.0.0.0/4").unwrap(),   // Reserved
63                IpNetwork::from_str("255.255.255.255/32").unwrap(), // Broadcast
64            ]
65        };
66
67        Self {
68            blocked_ranges,
69            allow_private_ips,
70        }
71    }
72
73    /// Add a custom blocked CIDR range
74    pub fn add_blocked_range(&mut self, cidr: &str) -> Result<(), String> {
75        let network =
76            IpNetwork::from_str(cidr).map_err(|e| format!("Invalid CIDR notation: {}", e))?;
77        self.blocked_ranges.push(network);
78        Ok(())
79    }
80
81    /// Validate a URL for SSRF vulnerabilities
82    pub fn validate_url(&self, url: &str) -> Result<(), String> {
83        // Parse the URL
84        let parsed_url = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
85
86        // 1. Validate scheme (only http/https allowed)
87        let scheme = parsed_url.scheme();
88        if scheme != "http" && scheme != "https" {
89            return Err(format!(
90                "Invalid URL scheme '{}': only http:// and https:// are allowed",
91                scheme
92            ));
93        }
94
95        // 2. Get the host
96        let host = parsed_url
97            .host_str()
98            .ok_or_else(|| "URL has no host".to_string())?;
99
100        // 3. Check for Unicode homograph attacks (non-ASCII characters in domain)
101        if !host.is_ascii() {
102            return Err(format!(
103                "Security: Domain contains non-ASCII characters (potential homograph attack): {}",
104                host
105            ));
106        }
107
108        // 4. Resolve hostname to IP addresses
109        let socket_addr_str = if let Some(port) = parsed_url.port() {
110            format!("{}:{}", host, port)
111        } else {
112            // Use default ports for scheme
113            let default_port = if scheme == "https" { 443 } else { 80 };
114            format!("{}:{}", host, default_port)
115        };
116
117        let ip_addrs: Vec<IpAddr> = socket_addr_str
118            .to_socket_addrs()
119            .map_err(|e| format!("Failed to resolve hostname '{}': {}", host, e))?
120            .map(|sa| sa.ip())
121            .collect();
122
123        if ip_addrs.is_empty() {
124            return Err(format!("Hostname '{}' resolved to no IP addresses", host));
125        }
126
127        // 5. Check all resolved IPs against blocked ranges
128        for ip in &ip_addrs {
129            self.validate_ip(ip)?;
130        }
131
132        // 6. DNS rebinding protection: resolve again and verify IPs haven't changed
133        // This helps detect time-of-check-time-of-use attacks
134        let recheck_ips: Vec<IpAddr> = socket_addr_str
135            .to_socket_addrs()
136            .map_err(|e| format!("DNS recheck failed for '{}': {}", host, e))?
137            .map(|sa| sa.ip())
138            .collect();
139
140        // Verify that all IPs from both resolutions are safe
141        for ip in &recheck_ips {
142            self.validate_ip(ip)?;
143        }
144
145        // Check if IP sets differ (potential DNS rebinding)
146        if ip_addrs.len() != recheck_ips.len()
147            || !ip_addrs.iter().all(|ip| recheck_ips.contains(ip))
148        {
149            warn!(
150                host = %host,
151                initial_ips = ?ip_addrs,
152                recheck_ips = ?recheck_ips,
153                "DNS resolution changed between checks — possible DNS rebinding"
154            );
155            // Allow it but log the warning - legitimate round-robin DNS can cause this
156        }
157
158        Ok(())
159    }
160
161    /// Validate a single IP address against blocked ranges
162    fn validate_ip(&self, ip: &IpAddr) -> Result<(), String> {
163        for blocked_range in &self.blocked_ranges {
164            if blocked_range.contains(*ip) {
165                return Err(format!(
166                    "Security: Access to {} is blocked (matches blocked range {})",
167                    ip, blocked_range
168                ));
169            }
170        }
171        Ok(())
172    }
173
174    /// Check if a URL would be blocked (non-failing version for testing)
175    pub fn is_blocked(&self, url: &str) -> bool {
176        self.validate_url(url).is_err()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_blocks_private_ips() {
        let validator = SsrfValidator::new(false);

        // Private IPs should be blocked
        assert!(validator.is_blocked("http://192.168.1.1/"));
        assert!(validator.is_blocked("http://10.0.0.1/"));
        assert!(validator.is_blocked("http://172.16.0.1/"));
    }

    #[test]
    fn test_blocks_localhost() {
        let validator = SsrfValidator::new(false);

        assert!(validator.is_blocked("http://127.0.0.1/"));
        assert!(validator.is_blocked("http://127.0.0.1:8080/"));
        assert!(validator.is_blocked("http://localhost/"));
        assert!(validator.is_blocked("http://[::1]/"));
        // IPv4-mapped IPv6 form of the loopback address
        assert!(validator.is_blocked("http://[::ffff:127.0.0.1]/"));
    }

    #[test]
    fn test_blocks_cloud_metadata() {
        let validator = SsrfValidator::new(false);

        assert!(validator.is_blocked("http://169.254.169.254/latest/meta-data/"));
    }

    #[test]
    fn test_blocks_invalid_schemes() {
        let validator = SsrfValidator::new(false);

        assert!(validator.is_blocked("file:///etc/passwd"));
        assert!(validator.is_blocked("ftp://example.com/"));
        assert!(validator.is_blocked("javascript:alert(1)"));

        let err = validator.validate_url("ftp://example.com/").unwrap_err();
        assert!(
            err.contains("Invalid URL scheme 'ftp'"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_rejects_unparseable_urls() {
        let validator = SsrfValidator::new(false);

        assert!(validator.is_blocked("not a url"));
    }

    #[test]
    fn test_allows_public_urls() {
        let validator = SsrfValidator::new(false);

        // Public hosts must not be rejected for SSRF reasons. DNS resolution
        // itself may fail in a sandboxed test environment, which is tolerated.
        let result = validator.validate_url("https://example.com/");
        if let Err(e) = result {
            assert!(!e.contains("Security:"), "Should not be a security error");
        }
    }

    #[test]
    fn test_allow_private_ips_override() {
        let validator = SsrfValidator::new(true);

        // IP literals are parsed without any DNS lookup, so these results are
        // deterministic. Private addresses are allowed...
        assert!(!validator.is_blocked("http://192.168.1.1/"));
        assert!(!validator.is_blocked("http://10.0.0.1/"));
        // ...but the metadata endpoint stays blocked.
        assert!(validator.is_blocked("http://169.254.169.254/latest/meta-data/"));
    }

    #[test]
    fn test_custom_blocked_range() {
        let mut validator = SsrfValidator::new(false);
        validator.add_blocked_range("8.8.8.0/24").unwrap();

        // 8.8.8.8 should now be blocked
        assert!(validator.is_blocked("http://8.8.8.8/"));
    }

    #[test]
    fn test_rejects_invalid_custom_range() {
        let mut validator = SsrfValidator::new(false);

        assert!(validator.add_blocked_range("not-a-cidr").is_err());
        assert!(validator.add_blocked_range("10.0.0.0/33").is_err());
    }
}