rustyclaw_core/security/
ssrf.rs1use ipnetwork::IpNetwork;
11use std::net::{IpAddr, ToSocketAddrs};
12use std::str::FromStr;
13use tracing::warn;
14
15#[derive(Debug, Clone)]
17pub struct SsrfValidator {
18 blocked_ranges: Vec<IpNetwork>,
20 #[allow(dead_code)]
22 allow_private_ips: bool,
23}
24
25impl Default for SsrfValidator {
26 fn default() -> Self {
27 Self::new(false)
28 }
29}
30
31impl SsrfValidator {
32 pub fn new(allow_private_ips: bool) -> Self {
34 let blocked_ranges = if allow_private_ips {
35 vec![
36 IpNetwork::from_str("169.254.169.254/32").unwrap(), ]
39 } else {
40 vec![
41 IpNetwork::from_str("10.0.0.0/8").unwrap(),
43 IpNetwork::from_str("172.16.0.0/12").unwrap(),
44 IpNetwork::from_str("192.168.0.0/16").unwrap(),
45 IpNetwork::from_str("127.0.0.0/8").unwrap(),
47 IpNetwork::from_str("::1/128").unwrap(),
48 IpNetwork::from_str("169.254.0.0/16").unwrap(),
50 IpNetwork::from_str("fe80::/10").unwrap(),
51 IpNetwork::from_str("::ffff:127.0.0.0/104").unwrap(), IpNetwork::from_str("0.0.0.0/8").unwrap(),
55 IpNetwork::from_str("100.64.0.0/10").unwrap(), IpNetwork::from_str("192.0.0.0/24").unwrap(), IpNetwork::from_str("192.0.2.0/24").unwrap(), IpNetwork::from_str("198.18.0.0/15").unwrap(), IpNetwork::from_str("198.51.100.0/24").unwrap(), IpNetwork::from_str("203.0.113.0/24").unwrap(), IpNetwork::from_str("224.0.0.0/4").unwrap(), IpNetwork::from_str("240.0.0.0/4").unwrap(), IpNetwork::from_str("255.255.255.255/32").unwrap(), ]
65 };
66
67 Self {
68 blocked_ranges,
69 allow_private_ips,
70 }
71 }
72
73 pub fn add_blocked_range(&mut self, cidr: &str) -> Result<(), String> {
75 let network =
76 IpNetwork::from_str(cidr).map_err(|e| format!("Invalid CIDR notation: {}", e))?;
77 self.blocked_ranges.push(network);
78 Ok(())
79 }
80
81 pub fn validate_url(&self, url: &str) -> Result<(), String> {
83 let parsed_url = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
85
86 let scheme = parsed_url.scheme();
88 if scheme != "http" && scheme != "https" {
89 return Err(format!(
90 "Invalid URL scheme '{}': only http:// and https:// are allowed",
91 scheme
92 ));
93 }
94
95 let host = parsed_url
97 .host_str()
98 .ok_or_else(|| "URL has no host".to_string())?;
99
100 if !host.is_ascii() {
102 return Err(format!(
103 "Security: Domain contains non-ASCII characters (potential homograph attack): {}",
104 host
105 ));
106 }
107
108 let socket_addr_str = if let Some(port) = parsed_url.port() {
110 format!("{}:{}", host, port)
111 } else {
112 let default_port = if scheme == "https" { 443 } else { 80 };
114 format!("{}:{}", host, default_port)
115 };
116
117 let ip_addrs: Vec<IpAddr> = socket_addr_str
118 .to_socket_addrs()
119 .map_err(|e| format!("Failed to resolve hostname '{}': {}", host, e))?
120 .map(|sa| sa.ip())
121 .collect();
122
123 if ip_addrs.is_empty() {
124 return Err(format!("Hostname '{}' resolved to no IP addresses", host));
125 }
126
127 for ip in &ip_addrs {
129 self.validate_ip(ip)?;
130 }
131
132 let recheck_ips: Vec<IpAddr> = socket_addr_str
135 .to_socket_addrs()
136 .map_err(|e| format!("DNS recheck failed for '{}': {}", host, e))?
137 .map(|sa| sa.ip())
138 .collect();
139
140 for ip in &recheck_ips {
142 self.validate_ip(ip)?;
143 }
144
145 if ip_addrs.len() != recheck_ips.len()
147 || !ip_addrs.iter().all(|ip| recheck_ips.contains(ip))
148 {
149 warn!(
150 host = %host,
151 initial_ips = ?ip_addrs,
152 recheck_ips = ?recheck_ips,
153 "DNS resolution changed between checks — possible DNS rebinding"
154 );
155 }
157
158 Ok(())
159 }
160
161 fn validate_ip(&self, ip: &IpAddr) -> Result<(), String> {
163 for blocked_range in &self.blocked_ranges {
164 if blocked_range.contains(*ip) {
165 return Err(format!(
166 "Security: Access to {} is blocked (matches blocked range {})",
167 ip, blocked_range
168 ));
169 }
170 }
171 Ok(())
172 }
173
174 pub fn is_blocked(&self, url: &str) -> bool {
176 self.validate_url(url).is_err()
177 }
178}
179
180#[cfg(test)]
181mod tests {
182 use super::*;
183
184 #[test]
185 fn test_blocks_private_ips() {
186 let validator = SsrfValidator::new(false);
187
188 assert!(validator.is_blocked("http://192.168.1.1/"));
190 assert!(validator.is_blocked("http://10.0.0.1/"));
191 assert!(validator.is_blocked("http://172.16.0.1/"));
192 }
193
194 #[test]
195 fn test_blocks_localhost() {
196 let validator = SsrfValidator::new(false);
197
198 assert!(validator.is_blocked("http://127.0.0.1/"));
199 assert!(validator.is_blocked("http://localhost/"));
200 }
201
202 #[test]
203 fn test_blocks_cloud_metadata() {
204 let validator = SsrfValidator::new(false);
205
206 assert!(validator.is_blocked("http://169.254.169.254/latest/meta-data/"));
207 }
208
209 #[test]
210 fn test_blocks_invalid_schemes() {
211 let validator = SsrfValidator::new(false);
212
213 assert!(validator.is_blocked("file:///etc/passwd"));
214 assert!(validator.is_blocked("ftp://example.com/"));
215 assert!(validator.is_blocked("javascript:alert(1)"));
216 }
217
218 #[test]
219 fn test_allows_public_urls() {
220 let validator = SsrfValidator::new(false);
221
222 let result = validator.validate_url("https://example.com/");
225 if let Err(e) = result {
227 assert!(!e.contains("Security:"), "Should not be a security error");
228 }
229 }
230
231 #[test]
232 fn test_allow_private_ips_override() {
233 let validator = SsrfValidator::new(true);
234
235 let result = validator.validate_url("http://192.168.1.1/");
239 if let Err(e) = result {
240 assert!(!e.contains("Security:") || e.contains("Failed to resolve"));
242 }
243 }
244
245 #[test]
246 fn test_custom_blocked_range() {
247 let mut validator = SsrfValidator::new(false);
248 validator.add_blocked_range("8.8.8.0/24").unwrap();
249
250 assert!(validator.is_blocked("http://8.8.8.8/"));
252 }
253}