url_preview/
security.rs

1use crate::error::PreviewError;
2use std::collections::HashSet;
3use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4use url::Url;
5
/// Security policy enforced by [`UrlValidator`] when validating URLs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UrlValidationConfig {
    /// URL schemes that are accepted (default: `http` and `https`).
    pub allowed_schemes: HashSet<String>,
    /// Reject hosts that are literal private/reserved IP addresses (default: `true`).
    pub block_private_ips: bool,
    /// Reject localhost hosts such as `localhost`, `127.0.0.1` and `[::1]` (default: `true`).
    pub block_localhost: bool,
    /// Domains to reject; an entry also matches every subdomain of it (default: empty).
    pub blocked_domains: HashSet<String>,
    /// Domain allow-list: when non-empty, only these domains and their subdomains
    /// are accepted (default: empty).
    pub allowed_domains: HashSet<String>,
    /// Maximum number of redirects allowed (default: 10).
    pub max_redirects: usize,
}
22
23impl Default for UrlValidationConfig {
24    fn default() -> Self {
25        let mut allowed_schemes = HashSet::new();
26        allowed_schemes.insert("http".to_string());
27        allowed_schemes.insert("https".to_string());
28
29        Self {
30            allowed_schemes,
31            block_private_ips: true,
32            block_localhost: true,
33            blocked_domains: HashSet::new(),
34            allowed_domains: HashSet::new(),
35            max_redirects: 10,
36        }
37    }
38}
39
40/// Validates a URL according to security policies
41#[derive(Clone)]
42pub struct UrlValidator {
43    config: UrlValidationConfig,
44}
45
46impl UrlValidator {
47    pub fn new(config: UrlValidationConfig) -> Self {
48        Self { config }
49    }
50
51    pub fn with_default_config() -> Self {
52        Self::new(UrlValidationConfig::default())
53    }
54
55    /// Validates a URL string
56    pub fn validate(&self, url_str: &str) -> Result<Url, PreviewError> {
57        // Parse the URL
58        let url = Url::parse(url_str).map_err(PreviewError::UrlParseError)?;
59
60        // Check scheme
61        if !self.config.allowed_schemes.contains(url.scheme()) {
62            return Err(PreviewError::InvalidUrlScheme(url.scheme().to_string()));
63        }
64
65        // Extract host
66        let host = url
67            .host_str()
68            .ok_or_else(|| PreviewError::InvalidUrl("No host in URL".to_string()))?;
69
70        // Check domain whitelist/blacklist
71        if !self.config.allowed_domains.is_empty() {
72            if !self.is_domain_allowed(host) {
73                return Err(PreviewError::DomainNotAllowed(host.to_string()));
74            }
75        } else if self.is_domain_blocked(host) {
76            return Err(PreviewError::DomainBlocked(host.to_string()));
77        }
78
79        // Check for localhost
80        if self.config.block_localhost && self.is_localhost(host) {
81            return Err(PreviewError::LocalhostBlocked);
82        }
83
84        // Check for private IPs if host is an IP address
85        if self.config.block_private_ips {
86            // Handle IPv6 addresses which may be wrapped in brackets
87            let ip_str = if host.starts_with('[') && host.ends_with(']') {
88                &host[1..host.len() - 1]
89            } else {
90                host
91            };
92            
93            if let Ok(ip) = ip_str.parse::<IpAddr>() {
94                if self.is_private_ip(&ip) {
95                    return Err(PreviewError::PrivateIpBlocked(ip.to_string()));
96                }
97            }
98        }
99
100        Ok(url)
101    }
102
103    fn is_domain_allowed(&self, host: &str) -> bool {
104        self.config
105            .allowed_domains
106            .iter()
107            .any(|allowed| host == allowed || host.ends_with(&format!(".{allowed}")))
108    }
109
110    fn is_domain_blocked(&self, host: &str) -> bool {
111        self.config
112            .blocked_domains
113            .iter()
114            .any(|blocked| host == blocked || host.ends_with(&format!(".{blocked}")))
115    }
116
117    fn is_localhost(&self, host: &str) -> bool {
118        matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
119    }
120
121    fn is_private_ip(&self, ip: &IpAddr) -> bool {
122        match ip {
123            IpAddr::V4(ipv4) => {
124                ipv4.is_private()
125                    || ipv4.is_loopback()
126                    || ipv4.is_link_local()
127                    || ipv4.is_unspecified()
128                    || self.is_ipv4_reserved(ipv4)
129            }
130            IpAddr::V6(ipv6) => {
131                ipv6.is_loopback()
132                    || ipv6.is_unspecified()
133                    || self.is_ipv6_link_local(ipv6)
134                    || self.is_ipv6_unique_local(ipv6)
135            }
136        }
137    }
138
139    fn is_ipv4_reserved(&self, ip: &Ipv4Addr) -> bool {
140        // Check for additional reserved ranges
141        let octets = ip.octets();
142
143        // 0.0.0.0/8
144        octets[0] == 0
145            // 10.0.0.0/8
146            || octets[0] == 10
147            // 100.64.0.0/10 (Carrier-grade NAT)
148            || (octets[0] == 100 && (octets[1] & 0b11000000) == 0b01000000)
149            // 169.254.0.0/16 (Link-local)
150            || (octets[0] == 169 && octets[1] == 254)
151            // 172.16.0.0/12
152            || (octets[0] == 172 && (octets[1] >= 16 && octets[1] <= 31))
153            // 192.168.0.0/16
154            || (octets[0] == 192 && octets[1] == 168)
155            // 224.0.0.0/4 (Multicast)
156            || (octets[0] & 0b11110000) == 0b11100000
157            // 240.0.0.0/4 (Reserved)
158            || (octets[0] & 0b11110000) == 0b11110000
159    }
160
161    fn is_ipv6_link_local(&self, ip: &Ipv6Addr) -> bool {
162        // fe80::/10
163        let segments = ip.segments();
164        (segments[0] & 0xffc0) == 0xfe80
165    }
166
167    fn is_ipv6_unique_local(&self, ip: &Ipv6Addr) -> bool {
168        // fc00::/7
169        let segments = ip.segments();
170        (segments[0] & 0xfe00) == 0xfc00
171    }
172}
173
/// Content size, download-time and content-type limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContentLimits {
    /// Maximum content size in bytes (default: 10 MiB, i.e. `10 * 1024 * 1024`).
    pub max_content_size: usize,
    /// Maximum download time in seconds (default: 30).
    pub max_download_time: u64,
    /// Allowed content types; when non-empty, only these types are allowed
    /// (default: `text/html`, `application/xhtml+xml`, `text/plain`, `application/json`).
    pub allowed_content_types: HashSet<String>,
}
184
185impl Default for ContentLimits {
186    fn default() -> Self {
187        let mut allowed_types = HashSet::new();
188        allowed_types.insert("text/html".to_string());
189        allowed_types.insert("application/xhtml+xml".to_string());
190        allowed_types.insert("text/plain".to_string());
191        allowed_types.insert("application/json".to_string());
192
193        Self {
194            max_content_size: 10 * 1024 * 1024, // 10MB
195            max_download_time: 30,
196            allowed_content_types: allowed_types,
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_url_validator_schemes() {
        let validator = UrlValidator::with_default_config();

        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("http://example.com").is_ok());
        assert!(validator.validate("ftp://example.com").is_err());
        assert!(validator.validate("file:///etc/passwd").is_err());
    }

    #[test]
    fn test_url_validator_missing_host() {
        let validator = UrlValidator::with_default_config();

        assert!(validator.validate("http://").is_err());
    }

    #[test]
    fn test_url_validator_localhost() {
        let validator = UrlValidator::with_default_config();

        assert!(validator.validate("http://localhost").is_err());
        assert!(validator.validate("http://127.0.0.1").is_err());
        assert!(validator.validate("http://[::1]").is_err());
    }

    #[test]
    fn test_url_validator_private_ips() {
        let validator = UrlValidator::with_default_config();

        assert!(validator.validate("http://10.0.0.1").is_err());
        assert!(validator.validate("http://192.168.1.1").is_err());
        assert!(validator.validate("http://172.16.0.1").is_err());
        assert!(validator.validate("http://169.254.1.1").is_err());
    }

    #[test]
    fn test_url_validator_reserved_ipv4_ranges() {
        let validator = UrlValidator::with_default_config();

        // Unspecified, carrier-grade NAT, multicast and the 240/4 reserved block.
        assert!(validator.validate("http://0.0.0.0").is_err());
        assert!(validator.validate("http://100.64.0.1").is_err());
        assert!(validator.validate("http://224.0.0.1").is_err());
        assert!(validator.validate("http://240.0.0.1").is_err());
    }

    #[test]
    fn test_url_validator_ipv6_private_ranges() {
        let validator = UrlValidator::with_default_config();

        // Link-local, unique local and unspecified IPv6 literals.
        assert!(validator.validate("http://[fe80::1]").is_err());
        assert!(validator.validate("http://[fc00::1]").is_err());
        assert!(validator.validate("http://[::]").is_err());
    }

    #[test]
    fn test_url_validator_public_ips_allowed() {
        let validator = UrlValidator::with_default_config();

        assert!(validator.validate("http://8.8.8.8").is_ok());
        assert!(validator.validate("http://[2606:4700::1111]").is_ok());
    }

    #[test]
    fn test_url_validator_checks_can_be_disabled() {
        let config = UrlValidationConfig {
            block_private_ips: false,
            block_localhost: false,
            ..UrlValidationConfig::default()
        };
        let validator = UrlValidator::new(config);

        assert!(validator.validate("http://10.0.0.1").is_ok());
        assert!(validator.validate("http://localhost").is_ok());
        assert!(validator.validate("http://127.0.0.1").is_ok());
    }

    #[test]
    fn test_url_validator_domain_lists() {
        let mut config = UrlValidationConfig::default();
        config.blocked_domains.insert("evil.com".to_string());
        let validator = UrlValidator::new(config);

        assert!(validator.validate("http://evil.com").is_err());
        assert!(validator.validate("http://sub.evil.com").is_err());
        assert!(validator.validate("http://good.com").is_ok());
        // Suffix match must respect label boundaries.
        assert!(validator.validate("http://notevil.com").is_ok());
    }

    #[test]
    fn test_url_validator_whitelist() {
        let mut config = UrlValidationConfig::default();
        config.allowed_domains.insert("trusted.com".to_string());
        let validator = UrlValidator::new(config);

        assert!(validator.validate("http://trusted.com").is_ok());
        assert!(validator.validate("http://sub.trusted.com").is_ok());
        assert!(validator.validate("http://untrusted.com").is_err());
        // A lookalike that merely ends with the allowed name is not a subdomain.
        assert!(validator.validate("http://nottrusted.com").is_err());
    }
}