1#[cfg(test)]
2use crate::config::NetworkMode;
3use anyhow::Context;
4use anyhow::Result;
5use anyhow::bail;
6use anyhow::ensure;
7use globset::GlobBuilder;
8use globset::GlobSet;
9use globset::GlobSetBuilder;
10use std::collections::HashSet;
11use std::net::IpAddr;
12use std::net::Ipv4Addr;
13use std::net::Ipv6Addr;
14use url::Host as UrlHost;
15
/// A host name held in normalized form.
///
/// The inner string is private, so values are created through
/// [`Host::parse`], which stores the output of [`normalize_host`] and rejects
/// input that normalizes to an empty string.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Host(String);
19
20impl Host {
21 pub fn parse(input: &str) -> Result<Self> {
22 let normalized = normalize_host(input);
23 ensure!(!normalized.is_empty(), "host is empty");
24 Ok(Self(normalized))
25 }
26
27 pub fn as_str(&self) -> &str {
28 &self.0
29 }
30}
31
32pub fn is_loopback_host(host: &Host) -> bool {
34 let host = host.as_str();
35 let host = host.split_once('%').map(|(ip, _)| ip).unwrap_or(host);
36 if host == "localhost" {
37 return true;
38 }
39 if let Ok(ip) = host.parse::<IpAddr>() {
40 return ip.is_loopback();
41 }
42 false
43}
44
45pub fn is_non_public_ip(ip: IpAddr) -> bool {
46 match ip {
47 IpAddr::V4(ip) => is_non_public_ipv4(ip),
48 IpAddr::V6(ip) => is_non_public_ipv6(ip),
49 }
50}
51
52fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
53 ip.is_loopback()
57 || ip.is_private()
58 || ip.is_link_local()
59 || ip.is_unspecified()
60 || ip.is_multicast()
61 || ip.is_broadcast()
62 || ipv4_in_cidr(ip, [0, 0, 0, 0], 8) || ipv4_in_cidr(ip, [100, 64, 0, 0], 10) || ipv4_in_cidr(ip, [192, 0, 0, 0], 24) || ipv4_in_cidr(ip, [192, 0, 2, 0], 24) || ipv4_in_cidr(ip, [198, 18, 0, 0], 15) || ipv4_in_cidr(ip, [198, 51, 100, 0], 24) || ipv4_in_cidr(ip, [203, 0, 113, 0], 24) || ipv4_in_cidr(ip, [240, 0, 0, 0], 4) }
71
/// Reports whether `ip` lies inside the CIDR block `base/prefix`.
///
/// `prefix` is the number of leading network bits. A prefix of `0` matches
/// every address. Prefixes above `32` are treated as `32` (exact match); the
/// naive `32 - prefix` would underflow for them.
fn ipv4_in_cidr(ip: Ipv4Addr, base: [u8; 4], prefix: u8) -> bool {
    // Number of trailing host bits; saturates to 0 for out-of-range prefixes.
    let host_bits = 32u32.saturating_sub(u32::from(prefix));
    // `checked_shl` returns `None` for a shift of 32, i.e. an all-zero mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);

    let address = u32::from(ip);
    let network = u32::from_be_bytes(base);
    (address & mask) == (network & mask)
}
82
83fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
84 if let Some(v4) = ip.to_ipv4() {
85 return is_non_public_ipv4(v4) || ip.is_loopback();
86 }
87 ip.is_loopback()
94 || ip.is_unspecified()
95 || ip.is_multicast()
96 || ip.is_unique_local()
97 || ip.is_unicast_link_local()
98}
99
100pub fn normalize_host(host: &str) -> String {
102 let host = host.trim();
103 if host.starts_with('[')
104 && let Some(end) = host.find(']')
105 {
106 return normalize_dns_host(&host[1..end]);
107 }
108
109 if host.bytes().filter(|b| *b == b':').count() == 1 {
112 let host = host.split(':').next().unwrap_or_default();
113 return normalize_dns_host(host);
114 }
115
116 normalize_dns_host(host)
119}
120
/// ASCII-lowercases `host` and strips every trailing `.`.
fn normalize_dns_host(host: &str) -> String {
    // Lowercasing never touches '.', so trimming first gives the same result.
    let mut normalized = host.trim_end_matches('.').to_owned();
    normalized.make_ascii_lowercase();
    normalized
}
125
126fn normalize_pattern(pattern: &str) -> String {
127 let pattern = pattern.trim();
128 if pattern == "*" {
129 return "*".to_string();
130 }
131
132 let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
133 ("**.", domain)
134 } else if let Some(domain) = pattern.strip_prefix("*.") {
135 ("*.", domain)
136 } else {
137 ("", pattern)
138 };
139
140 let remainder = normalize_host(remainder);
141 if prefix.is_empty() {
142 remainder
143 } else {
144 format!("{prefix}{remainder}")
145 }
146}
147
148pub(crate) fn is_global_wildcard_domain_pattern(pattern: &str) -> bool {
149 let normalized = normalize_pattern(pattern);
150 expand_domain_pattern(&normalized)
151 .iter()
152 .any(|candidate| candidate == "*")
153}
154
/// How glob-set compilation treats a pattern that expands to the bare `*`.
#[derive(Copy, Clone, Eq, PartialEq)]
enum GlobalWildcard {
    /// Compile the pattern like any other.
    Allow,
    /// Fail compilation with an error.
    Reject,
}
160
161pub(crate) fn compile_allowlist_globset(patterns: &[String]) -> Result<GlobSet> {
162 compile_globset_with_policy(patterns, GlobalWildcard::Allow)
163}
164
165pub(crate) fn compile_denylist_globset(patterns: &[String]) -> Result<GlobSet> {
166 compile_globset_with_policy(patterns, GlobalWildcard::Reject)
167}
168
169fn compile_globset_with_policy(
170 patterns: &[String],
171 global_wildcard: GlobalWildcard,
172) -> Result<GlobSet> {
173 let mut builder = GlobSetBuilder::new();
174 let mut seen = HashSet::new();
175 for pattern in patterns {
176 if global_wildcard == GlobalWildcard::Reject && is_global_wildcard_domain_pattern(pattern) {
177 bail!(
178 "unsupported global wildcard domain pattern \"*\"; use exact hosts or scoped wildcards like *.example.com or **.example.com"
179 );
180 }
181 let pattern = normalize_pattern(pattern);
182 for candidate in expand_domain_pattern(&pattern) {
188 if !seen.insert(candidate.clone()) {
189 continue;
190 }
191 let glob = GlobBuilder::new(&candidate)
192 .case_insensitive(true)
193 .build()
194 .with_context(|| format!("invalid glob pattern: {candidate}"))?;
195 builder.add(glob);
196 }
197 }
198 Ok(builder.build()?)
199}
200
/// A domain pattern classified by its wildcard prefix.
#[derive(Clone, Debug)]
pub(crate) enum DomainPattern {
    /// Written `**.domain`: the domain itself together with its subdomains.
    ApexAndSubdomains(String),
    /// Written `*.domain`: subdomains only, not the domain itself.
    SubdomainsOnly(String),
    /// Any pattern without one of the prefixes above.
    Exact(String),
}
207
208impl DomainPattern {
209 pub(crate) fn parse(input: &str) -> Self {
214 let input = input.trim();
215 if input.is_empty() {
216 return Self::Exact(String::new());
217 }
218 if let Some(domain) = input.strip_prefix("**.") {
219 Self::parse_domain(domain, Self::ApexAndSubdomains)
220 } else if let Some(domain) = input.strip_prefix("*.") {
221 Self::parse_domain(domain, Self::SubdomainsOnly)
222 } else {
223 Self::Exact(input.to_string())
224 }
225 }
226
227 pub(crate) fn parse_for_constraints(input: &str) -> Self {
229 let input = input.trim();
230 if input.is_empty() {
231 return Self::Exact(String::new());
232 }
233 if let Some(domain) = input.strip_prefix("**.") {
234 return Self::ApexAndSubdomains(parse_domain_for_constraints(domain));
235 }
236 if let Some(domain) = input.strip_prefix("*.") {
237 return Self::SubdomainsOnly(parse_domain_for_constraints(domain));
238 }
239 Self::Exact(parse_domain_for_constraints(input))
240 }
241
242 fn parse_domain(domain: &str, build: impl FnOnce(String) -> Self) -> Self {
243 let domain = domain.trim();
244 if domain.is_empty() {
245 return Self::Exact(String::new());
246 }
247 build(domain.to_string())
248 }
249
250 pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
251 match self {
252 DomainPattern::Exact(domain) => match candidate {
253 DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
254 _ => false,
255 },
256 DomainPattern::SubdomainsOnly(domain) => match candidate {
257 DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
258 DomainPattern::SubdomainsOnly(candidate) => {
259 is_subdomain_or_equal(candidate, domain)
260 }
261 DomainPattern::ApexAndSubdomains(candidate) => {
262 is_strict_subdomain(candidate, domain)
263 }
264 },
265 DomainPattern::ApexAndSubdomains(domain) => match candidate {
266 DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
267 DomainPattern::SubdomainsOnly(candidate) => {
268 is_subdomain_or_equal(candidate, domain)
269 }
270 DomainPattern::ApexAndSubdomains(candidate) => {
271 is_subdomain_or_equal(candidate, domain)
272 }
273 },
274 }
275 }
276}
277
278fn parse_domain_for_constraints(domain: &str) -> String {
279 let domain = domain.trim().trim_end_matches('.');
280 if domain.is_empty() {
281 return String::new();
282 }
283 let host = if domain.starts_with('[') && domain.ends_with(']') {
284 &domain[1..domain.len().saturating_sub(1)]
285 } else {
286 domain
287 };
288 if host.contains('*') || host.contains('?') || host.contains('%') {
289 return domain.to_string();
290 }
291 match UrlHost::parse(host) {
292 Ok(host) => host.to_string(),
293 Err(_) => String::new(),
294 }
295}
296
297fn expand_domain_pattern(pattern: &str) -> Vec<String> {
298 match DomainPattern::parse(pattern) {
299 DomainPattern::Exact(domain) => vec![domain],
300 DomainPattern::SubdomainsOnly(domain) => {
301 vec![format!("?*.{domain}")]
302 }
303 DomainPattern::ApexAndSubdomains(domain) => {
304 vec![domain.clone(), format!("?*.{domain}")]
305 }
306 }
307}
308
/// Strips trailing dots from `domain` and ASCII-lowercases what remains.
fn normalize_domain(domain: &str) -> String {
    domain
        .trim_end_matches('.')
        .chars()
        .map(|c| c.to_ascii_lowercase())
        .collect()
}
312
313fn domain_eq(left: &str, right: &str) -> bool {
314 normalize_domain(left) == normalize_domain(right)
315}
316
317fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
318 let child = normalize_domain(child);
319 let parent = normalize_domain(parent);
320 if child == parent {
321 return true;
322 }
323 child.ends_with(&format!(".{parent}"))
324}
325
326fn is_strict_subdomain(child: &str, parent: &str) -> bool {
327 let child = normalize_domain(child);
328 let parent = normalize_domain(parent);
329 child != parent && child.ends_with(&format!(".{parent}"))
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;

    /// Compiles a deny-list glob set from a single pattern.
    fn denylist(pattern: &str) -> GlobSet {
        compile_denylist_globset(&[pattern.to_string()]).unwrap()
    }

    /// Parses `raw` as a [`Host`] and runs the loopback check on it.
    fn loopback(raw: &str) -> bool {
        is_loopback_host(&Host::parse(raw).unwrap())
    }

    /// Parses `raw` as an IP address and runs the non-public check on it.
    fn non_public(raw: &str) -> bool {
        is_non_public_ip(raw.parse().unwrap())
    }

    #[test]
    fn method_allowed_full_allows_everything() {
        for method in ["GET", "POST", "CONNECT"] {
            assert!(NetworkMode::Full.allows_method(method), "method: {method}");
        }
    }

    #[test]
    fn method_allowed_limited_allows_only_safe_methods() {
        for method in ["GET", "HEAD", "OPTIONS"] {
            assert!(
                NetworkMode::Limited.allows_method(method),
                "method: {method}"
            );
        }
        for method in ["POST", "CONNECT"] {
            assert!(
                !NetworkMode::Limited.allows_method(method),
                "method: {method}"
            );
        }
    }

    #[test]
    fn compile_globset_normalizes_trailing_dots() {
        let set = denylist("Example.COM.");

        for (host, expected) in [("example.com", true), ("api.example.com", false)] {
            assert_eq!(expected, set.is_match(host), "host: {host}");
        }
    }

    #[test]
    fn compile_globset_normalizes_wildcards() {
        let set = denylist("*.Example.COM.");

        for (host, expected) in [("api.example.com", true), ("example.com", false)] {
            assert_eq!(expected, set.is_match(host), "host: {host}");
        }
    }

    #[test]
    fn compile_globset_supports_mid_label_wildcards() {
        let set = denylist("region*.v2.argotunnel.com");

        let cases = [
            ("region1.v2.argotunnel.com", true),
            ("region.v2.argotunnel.com", true),
            ("xregion1.v2.argotunnel.com", false),
            ("foo.region1.v2.argotunnel.com", false),
        ];
        for (host, expected) in cases {
            assert_eq!(expected, set.is_match(host), "host: {host}");
        }
    }

    #[test]
    fn compile_globset_normalizes_apex_and_subdomains() {
        let set = denylist("**.Example.COM.");

        for host in ["example.com", "api.example.com"] {
            assert_eq!(true, set.is_match(host), "host: {host}");
        }
    }

    #[test]
    fn compile_globset_normalizes_bracketed_ipv6_literals() {
        let set = denylist("[::1]");

        assert_eq!(true, set.is_match("::1"));
    }

    #[test]
    fn is_loopback_host_handles_localhost_variants() {
        for host in ["localhost", "localhost.", "LOCALHOST"] {
            assert!(loopback(host), "host: {host}");
        }
        assert!(!loopback("notlocalhost"));
    }

    #[test]
    fn is_loopback_host_handles_ip_literals() {
        for host in ["127.0.0.1", "::1"] {
            assert!(loopback(host), "host: {host}");
        }
        assert!(!loopback("1.2.3.4"));
    }

    #[test]
    fn is_non_public_ip_rejects_private_and_loopback_ranges() {
        let non_public_addresses = [
            "127.0.0.1",
            "10.0.0.1",
            "192.168.0.1",
            "100.64.0.1",
            "192.0.0.1",
            "192.0.2.1",
            "198.18.0.1",
            "198.51.100.1",
            "203.0.113.1",
            "240.0.0.1",
            "0.1.2.3",
            "::ffff:127.0.0.1",
            "::ffff:10.0.0.1",
            "::1",
            "fe80::1",
            "fc00::1",
        ];
        for address in non_public_addresses {
            assert!(non_public(address), "address: {address}");
        }

        for address in ["8.8.8.8", "::ffff:8.8.8.8"] {
            assert!(!non_public(address), "address: {address}");
        }
    }

    #[test]
    fn normalize_host_lowercases_and_trims() {
        assert_eq!(normalize_host("  ExAmPlE.CoM  "), "example.com");
    }

    #[test]
    fn normalize_host_strips_port_for_host_port() {
        assert_eq!(normalize_host("example.com:1234"), "example.com");
    }

    #[test]
    fn normalize_host_preserves_unbracketed_ipv6() {
        assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn normalize_host_strips_trailing_dot() {
        for input in ["example.com.", "ExAmPlE.CoM."] {
            assert_eq!(normalize_host(input), "example.com");
        }
    }

    #[test]
    fn normalize_host_strips_trailing_dot_with_port() {
        assert_eq!(normalize_host("example.com.:443"), "example.com");
    }

    #[test]
    fn normalize_host_strips_brackets_for_ipv6() {
        for input in ["[::1]", "[::1]:443"] {
            assert_eq!(normalize_host(input), "::1");
        }
    }
}