1use crate::error::PreviewError;
2use std::collections::HashSet;
3use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4use url::Url;
5
/// Policy describing which URLs `UrlValidator::validate` accepts.
///
/// The [`Default`] policy accepts only `http`/`https`, blocks localhost and
/// private/reserved IP literals, and configures no domain lists.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UrlValidationConfig {
    /// URL schemes that are accepted, e.g. `"https"`.
    pub allowed_schemes: HashSet<String>,
    /// Reject URLs whose host is a private, loopback, link-local or otherwise
    /// reserved IP address literal.
    pub block_private_ips: bool,
    /// Reject URLs whose host refers to the local machine (e.g. `localhost`).
    pub block_localhost: bool,
    /// Domains that are rejected together with all of their subdomains.
    /// Ignored whenever `allowed_domains` is non-empty.
    pub blocked_domains: HashSet<String>,
    /// When non-empty, only these domains and their subdomains are accepted.
    pub allowed_domains: HashSet<String>,
    /// Maximum number of redirects to follow. Not read by `UrlValidator` itself;
    /// NOTE(review): presumably enforced by the HTTP client — confirm.
    pub max_redirects: usize,
}
22
23impl Default for UrlValidationConfig {
24 fn default() -> Self {
25 let mut allowed_schemes = HashSet::new();
26 allowed_schemes.insert("http".to_string());
27 allowed_schemes.insert("https".to_string());
28
29 Self {
30 allowed_schemes,
31 block_private_ips: true,
32 block_localhost: true,
33 blocked_domains: HashSet::new(),
34 allowed_domains: HashSet::new(),
35 max_redirects: 10,
36 }
37 }
38}
39
40#[derive(Clone)]
42pub struct UrlValidator {
43 config: UrlValidationConfig,
44}
45
46impl UrlValidator {
47 pub fn new(config: UrlValidationConfig) -> Self {
48 Self { config }
49 }
50
51 pub fn with_default_config() -> Self {
52 Self::new(UrlValidationConfig::default())
53 }
54
55 pub fn validate(&self, url_str: &str) -> Result<Url, PreviewError> {
57 let url = Url::parse(url_str).map_err(PreviewError::UrlParseError)?;
59
60 if !self.config.allowed_schemes.contains(url.scheme()) {
62 return Err(PreviewError::InvalidUrlScheme(url.scheme().to_string()));
63 }
64
65 let host = url
67 .host_str()
68 .ok_or_else(|| PreviewError::InvalidUrl("No host in URL".to_string()))?;
69
70 if !self.config.allowed_domains.is_empty() {
72 if !self.is_domain_allowed(host) {
73 return Err(PreviewError::DomainNotAllowed(host.to_string()));
74 }
75 } else if self.is_domain_blocked(host) {
76 return Err(PreviewError::DomainBlocked(host.to_string()));
77 }
78
79 if self.config.block_localhost && self.is_localhost(host) {
81 return Err(PreviewError::LocalhostBlocked);
82 }
83
84 if self.config.block_private_ips {
86 let ip_str = if host.starts_with('[') && host.ends_with(']') {
88 &host[1..host.len() - 1]
89 } else {
90 host
91 };
92
93 if let Ok(ip) = ip_str.parse::<IpAddr>() {
94 if self.is_private_ip(&ip) {
95 return Err(PreviewError::PrivateIpBlocked(ip.to_string()));
96 }
97 }
98 }
99
100 Ok(url)
101 }
102
103 fn is_domain_allowed(&self, host: &str) -> bool {
104 self.config
105 .allowed_domains
106 .iter()
107 .any(|allowed| host == allowed || host.ends_with(&format!(".{allowed}")))
108 }
109
110 fn is_domain_blocked(&self, host: &str) -> bool {
111 self.config
112 .blocked_domains
113 .iter()
114 .any(|blocked| host == blocked || host.ends_with(&format!(".{blocked}")))
115 }
116
117 fn is_localhost(&self, host: &str) -> bool {
118 matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]")
119 }
120
121 fn is_private_ip(&self, ip: &IpAddr) -> bool {
122 match ip {
123 IpAddr::V4(ipv4) => {
124 ipv4.is_private()
125 || ipv4.is_loopback()
126 || ipv4.is_link_local()
127 || ipv4.is_unspecified()
128 || self.is_ipv4_reserved(ipv4)
129 }
130 IpAddr::V6(ipv6) => {
131 ipv6.is_loopback()
132 || ipv6.is_unspecified()
133 || self.is_ipv6_link_local(ipv6)
134 || self.is_ipv6_unique_local(ipv6)
135 }
136 }
137 }
138
139 fn is_ipv4_reserved(&self, ip: &Ipv4Addr) -> bool {
140 let octets = ip.octets();
142
143 octets[0] == 0
145 || octets[0] == 10
147 || (octets[0] == 100 && (octets[1] & 0b11000000) == 0b01000000)
149 || (octets[0] == 169 && octets[1] == 254)
151 || (octets[0] == 172 && (octets[1] >= 16 && octets[1] <= 31))
153 || (octets[0] == 192 && octets[1] == 168)
155 || (octets[0] & 0b11110000) == 0b11100000
157 || (octets[0] & 0b11110000) == 0b11110000
159 }
160
161 fn is_ipv6_link_local(&self, ip: &Ipv6Addr) -> bool {
162 let segments = ip.segments();
164 (segments[0] & 0xffc0) == 0xfe80
165 }
166
167 fn is_ipv6_unique_local(&self, ip: &Ipv6Addr) -> bool {
168 let segments = ip.segments();
170 (segments[0] & 0xfe00) == 0xfc00
171 }
172}
173
/// Limits describing which downloaded content is acceptable.
///
/// NOTE(review): nothing in this file enforces these limits; presumably the
/// fetching code applies them — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContentLimits {
    /// Maximum content size; presumably in bytes (the default is
    /// `10 * 1024 * 1024`).
    pub max_content_size: usize,
    /// Maximum download time; presumably in seconds (the default is `30`).
    pub max_download_time: u64,
    /// Content (MIME) types that are accepted, e.g. `"text/html"`.
    pub allowed_content_types: HashSet<String>,
}
184
185impl Default for ContentLimits {
186 fn default() -> Self {
187 let mut allowed_types = HashSet::new();
188 allowed_types.insert("text/html".to_string());
189 allowed_types.insert("application/xhtml+xml".to_string());
190 allowed_types.insert("text/plain".to_string());
191 allowed_types.insert("application/json".to_string());
192
193 Self {
194 max_content_size: 10 * 1024 * 1024, max_download_time: 30,
196 allowed_content_types: allowed_types,
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `validator` accepts every URL in `urls`.
    fn assert_accepts(validator: &UrlValidator, urls: &[&str]) {
        for &url in urls {
            assert!(validator.validate(url).is_ok(), "expected `{url}` to be accepted");
        }
    }

    /// Fails the test unless `validator` rejects every URL in `urls`.
    fn assert_rejects(validator: &UrlValidator, urls: &[&str]) {
        for &url in urls {
            assert!(validator.validate(url).is_err(), "expected `{url}` to be rejected");
        }
    }

    /// Builds a validator from the default config after letting `tweak` adjust it.
    fn validator_with(tweak: impl FnOnce(&mut UrlValidationConfig)) -> UrlValidator {
        let mut config = UrlValidationConfig::default();
        tweak(&mut config);
        UrlValidator::new(config)
    }

    #[test]
    fn test_url_validator_schemes() {
        let validator = UrlValidator::with_default_config();
        assert_accepts(&validator, &["https://example.com", "http://example.com"]);
        assert_rejects(&validator, &["ftp://example.com", "file:///etc/passwd"]);
    }

    #[test]
    fn test_url_validator_localhost() {
        let validator = UrlValidator::with_default_config();
        assert_rejects(
            &validator,
            &["http://localhost", "http://127.0.0.1", "http://[::1]"],
        );
    }

    #[test]
    fn test_url_validator_private_ips() {
        let validator = UrlValidator::with_default_config();
        assert_rejects(
            &validator,
            &[
                "http://10.0.0.1",
                "http://192.168.1.1",
                "http://172.16.0.1",
                "http://169.254.1.1",
            ],
        );
    }

    #[test]
    fn test_url_validator_domain_lists() {
        let validator = validator_with(|config| {
            config.blocked_domains.insert("evil.com".to_string());
        });
        assert_rejects(&validator, &["http://evil.com", "http://sub.evil.com"]);
        assert_accepts(&validator, &["http://good.com"]);
    }

    #[test]
    fn test_url_validator_whitelist() {
        let validator = validator_with(|config| {
            config.allowed_domains.insert("trusted.com".to_string());
        });
        assert_accepts(&validator, &["http://trusted.com", "http://sub.trusted.com"]);
        assert_rejects(&validator, &["http://untrusted.com"]);
    }
}
255}