Skip to main content

zerobox_network_proxy/
policy.rs

1#[cfg(test)]
2use crate::config::NetworkMode;
3use anyhow::Context;
4use anyhow::Result;
5use anyhow::bail;
6use anyhow::ensure;
7use globset::GlobBuilder;
8use globset::GlobSet;
9use globset::GlobSetBuilder;
10use std::collections::HashSet;
11use std::net::IpAddr;
12use std::net::Ipv4Addr;
13use std::net::Ipv6Addr;
14use url::Host as UrlHost;
15
16/// A normalized host string for policy evaluation.
17#[derive(Clone, Debug, PartialEq, Eq, Hash)]
18pub struct Host(String);
19
20impl Host {
21    pub fn parse(input: &str) -> Result<Self> {
22        let normalized = normalize_host(input);
23        ensure!(!normalized.is_empty(), "host is empty");
24        Ok(Self(normalized))
25    }
26
27    pub fn as_str(&self) -> &str {
28        &self.0
29    }
30}
31
32/// Returns true if the host is a loopback hostname or IP literal.
33pub fn is_loopback_host(host: &Host) -> bool {
34    let host = host.as_str();
35    let host = unscoped_ip_literal(host).unwrap_or(host);
36    if host == "localhost" {
37        return true;
38    }
39    if let Ok(ip) = host.parse::<IpAddr>() {
40        return ip.is_loopback();
41    }
42    false
43}
44
45pub fn is_non_public_ip(ip: IpAddr) -> bool {
46    match ip {
47        IpAddr::V4(ip) => is_non_public_ipv4(ip),
48        IpAddr::V6(ip) => is_non_public_ipv6(ip),
49    }
50}
51
52fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
53    // Use the standard library classification helpers where possible; they encode the intent more
54    // clearly than hand-rolled range checks. Some non-public ranges (e.g., CGNAT and TEST-NET
55    // blocks) are not covered by stable stdlib helpers yet, so we fall back to CIDR checks.
56    ip.is_loopback()
57        || ip.is_private()
58        || ip.is_link_local()
59        || ip.is_unspecified()
60        || ip.is_multicast()
61        || ip.is_broadcast()
62        || ipv4_in_cidr(ip, [0, 0, 0, 0], /*prefix*/ 8) // "this network" (RFC 1122)
63        || ipv4_in_cidr(ip, [100, 64, 0, 0], /*prefix*/ 10) // CGNAT (RFC 6598)
64        || ipv4_in_cidr(ip, [192, 0, 0, 0], /*prefix*/ 24) // IETF Protocol Assignments (RFC 6890)
65        || ipv4_in_cidr(ip, [192, 0, 2, 0], /*prefix*/ 24) // TEST-NET-1 (RFC 5737)
66        || ipv4_in_cidr(ip, [198, 18, 0, 0], /*prefix*/ 15) // Benchmarking (RFC 2544)
67        || ipv4_in_cidr(ip, [198, 51, 100, 0], /*prefix*/ 24) // TEST-NET-2 (RFC 5737)
68        || ipv4_in_cidr(ip, [203, 0, 113, 0], /*prefix*/ 24) // TEST-NET-3 (RFC 5737)
69        || ipv4_in_cidr(ip, [240, 0, 0, 0], /*prefix*/ 4) // Reserved (RFC 6890)
70}
71
/// Reports whether `ip` lies inside the CIDR block `base/prefix`.
///
/// `base` is the network address as four octets and `prefix` is the number of
/// leading bits that must match. A prefix of `0` matches every address. A
/// prefix above 32 is treated as `/32` (exact match) instead of underflowing
/// the shift-amount computation.
fn ipv4_in_cidr(ip: Ipv4Addr, base: [u8; 4], prefix: u8) -> bool {
    let ip = u32::from(ip);
    let base = u32::from(Ipv4Addr::from(base));
    // Number of low-order (host) bits to clear; saturates at 0 for prefix >= 32.
    let host_bits = 32u32.saturating_sub(u32::from(prefix));
    // `checked_shl` yields `None` for a shift of 32 (prefix == 0), i.e. an empty mask.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);
    (ip & mask) == (base & mask)
}
82
83fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
84    if let Some(v4) = ip.to_ipv4() {
85        return is_non_public_ipv4(v4) || ip.is_loopback();
86    }
87    // Treat anything that isn't globally routable as "local" for SSRF prevention. In particular:
88    //  - `::1` loopback
89    //  - `fc00::/7` unique-local (RFC 4193)
90    //  - `fe80::/10` link-local
91    //  - `::` unspecified
92    //  - multicast ranges
93    ip.is_loopback()
94        || ip.is_unspecified()
95        || ip.is_multicast()
96        || ip.is_unique_local()
97        || ip.is_unicast_link_local()
98}
99
100/// Normalize host fragments for policy matching (trim whitespace, strip ports/brackets, lowercase).
101pub fn normalize_host(host: &str) -> String {
102    let host = host.trim();
103    if host.starts_with('[')
104        && let Some(end) = host.find(']')
105    {
106        return normalize_dns_host_or_ip_literal(&host[1..end]);
107    }
108
109    // The proxy stack should typically hand us a host without a port, but be
110    // defensive and strip `:port` when there is exactly one `:`.
111    if host.bytes().filter(|b| *b == b':').count() == 1 {
112        let host = host.split(':').next().unwrap_or_default();
113        return normalize_dns_host_or_ip_literal(host);
114    }
115
116    // Avoid mangling unbracketed IPv6 literals, but strip trailing dots so fully qualified domain
117    // names are treated the same as their dotless variants.
118    normalize_dns_host_or_ip_literal(host)
119}
120
121fn normalize_dns_host_or_ip_literal(host: &str) -> String {
122    let host = host.to_ascii_lowercase();
123    let host = host.trim_end_matches('.');
124    if let Some(ip) = normalize_ip_literal(host) {
125        return ip;
126    }
127    host.to_string()
128}
129
/// Returns the address portion of a scoped IP literal (`addr%zone`).
///
/// Yields `None` when `host` contains no `%` or when the text before the
/// first `%` is not a valid IP address.
pub(crate) fn unscoped_ip_literal(host: &str) -> Option<&str> {
    host.split_once('%')
        .map(|(addr, _zone)| addr)
        .filter(|addr| addr.parse::<IpAddr>().is_ok())
}
135
/// Normalizes an IP literal, optionally carrying an IPv6 zone identifier.
///
/// Returns `None` when `host` is not an IP literal. Otherwise the address is
/// rendered in its canonical textual form (for example `0:0:0:0:0:0:0:1`
/// becomes `::1`), so different spellings of one address normalize to the same
/// policy key. A zone suffix introduced by `%25` (URL-encoded) or `%` is kept
/// and always emitted after a plain `%`.
fn normalize_ip_literal(host: &str) -> Option<String> {
    if let Ok(ip) = host.parse::<IpAddr>() {
        return Some(ip.to_string());
    }
    // Try the URL-encoded delimiter first so `%25zone` is not misread as a
    // zone named `25zone`.
    ["%25", "%"].into_iter().find_map(|delimiter| {
        let (addr, scope) = host.split_once(delimiter)?;
        let ip = addr.parse::<IpAddr>().ok()?;
        Some(format!("{ip}%{scope}"))
    })
}
149
150fn normalize_pattern(pattern: &str) -> String {
151    let pattern = pattern.trim();
152    if pattern == "*" {
153        return "*".to_string();
154    }
155
156    let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
157        ("**.", domain)
158    } else if let Some(domain) = pattern.strip_prefix("*.") {
159        ("*.", domain)
160    } else {
161        ("", pattern)
162    };
163
164    let remainder = normalize_host(remainder);
165    if prefix.is_empty() {
166        remainder
167    } else {
168        format!("{prefix}{remainder}")
169    }
170}
171
172pub(crate) fn is_global_wildcard_domain_pattern(pattern: &str) -> bool {
173    let normalized = normalize_pattern(pattern);
174    expand_domain_pattern(&normalized)
175        .iter()
176        .any(|candidate| candidate == "*")
177}
178
/// How globset compilation treats a pattern that matches every host.
#[derive(Copy, Clone, Eq, PartialEq)]
enum GlobalWildcard {
    /// Compile the global wildcard like any other pattern.
    Allow,
    /// Fail compilation when a global wildcard is encountered.
    Reject,
}
184
185pub(crate) fn compile_allowlist_globset(patterns: &[String]) -> Result<GlobSet> {
186    compile_globset_with_policy(patterns, GlobalWildcard::Allow)
187}
188
189pub(crate) fn compile_denylist_globset(patterns: &[String]) -> Result<GlobSet> {
190    compile_globset_with_policy(patterns, GlobalWildcard::Reject)
191}
192
193fn compile_globset_with_policy(
194    patterns: &[String],
195    global_wildcard: GlobalWildcard,
196) -> Result<GlobSet> {
197    let mut builder = GlobSetBuilder::new();
198    let mut seen = HashSet::new();
199    for pattern in patterns {
200        if global_wildcard == GlobalWildcard::Reject && is_global_wildcard_domain_pattern(pattern) {
201            bail!(
202                "unsupported global wildcard domain pattern \"*\"; use exact hosts or scoped wildcards like *.example.com or **.example.com"
203            );
204        }
205        let pattern = normalize_pattern(pattern);
206        // Supported domain patterns:
207        // - "example.com": match the exact host
208        // - "*.example.com": match any subdomain (not the apex)
209        // - "**.example.com": match the apex and any subdomain
210        // - "*": match every host when explicitly enabled for allowlist compilation
211        for candidate in expand_domain_pattern(&pattern) {
212            if !seen.insert(candidate.clone()) {
213                continue;
214            }
215            let glob = GlobBuilder::new(&candidate)
216                .case_insensitive(true)
217                .build()
218                .with_context(|| format!("invalid glob pattern: {candidate}"))?;
219            builder.add(glob);
220        }
221    }
222    Ok(builder.build()?)
223}
224
/// A policy domain pattern decoded into its wildcard form.
#[derive(Clone, Debug)]
pub(crate) enum DomainPattern {
    /// `**.domain`: the domain itself and every subdomain.
    ApexAndSubdomains(String),
    /// `*.domain`: subdomains only, not the domain itself.
    SubdomainsOnly(String),
    /// Any pattern without a recognized wildcard prefix.
    Exact(String),
}
231
232impl DomainPattern {
233    /// Parse a policy pattern for constraint comparisons.
234    ///
235    /// Validation of glob syntax happens when building the globset; here we only
236    /// decode the wildcard prefixes to keep constraint checks lightweight.
237    pub(crate) fn parse(input: &str) -> Self {
238        let input = input.trim();
239        if input.is_empty() {
240            return Self::Exact(String::new());
241        }
242        if let Some(domain) = input.strip_prefix("**.") {
243            Self::parse_domain(domain, Self::ApexAndSubdomains)
244        } else if let Some(domain) = input.strip_prefix("*.") {
245            Self::parse_domain(domain, Self::SubdomainsOnly)
246        } else {
247            Self::Exact(input.to_string())
248        }
249    }
250
251    /// Parse a policy pattern for constraint comparisons, validating domain parts with `url`.
252    pub(crate) fn parse_for_constraints(input: &str) -> Self {
253        let input = input.trim();
254        if input.is_empty() {
255            return Self::Exact(String::new());
256        }
257        if let Some(domain) = input.strip_prefix("**.") {
258            return Self::ApexAndSubdomains(parse_domain_for_constraints(domain));
259        }
260        if let Some(domain) = input.strip_prefix("*.") {
261            return Self::SubdomainsOnly(parse_domain_for_constraints(domain));
262        }
263        Self::Exact(parse_domain_for_constraints(input))
264    }
265
266    fn parse_domain(domain: &str, build: impl FnOnce(String) -> Self) -> Self {
267        let domain = domain.trim();
268        if domain.is_empty() {
269            return Self::Exact(String::new());
270        }
271        build(domain.to_string())
272    }
273
274    pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
275        match self {
276            DomainPattern::Exact(domain) => match candidate {
277                DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
278                _ => false,
279            },
280            DomainPattern::SubdomainsOnly(domain) => match candidate {
281                DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
282                DomainPattern::SubdomainsOnly(candidate) => {
283                    is_subdomain_or_equal(candidate, domain)
284                }
285                DomainPattern::ApexAndSubdomains(candidate) => {
286                    is_strict_subdomain(candidate, domain)
287                }
288            },
289            DomainPattern::ApexAndSubdomains(domain) => match candidate {
290                DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
291                DomainPattern::SubdomainsOnly(candidate) => {
292                    is_subdomain_or_equal(candidate, domain)
293                }
294                DomainPattern::ApexAndSubdomains(candidate) => {
295                    is_subdomain_or_equal(candidate, domain)
296                }
297            },
298        }
299    }
300}
301
302fn parse_domain_for_constraints(domain: &str) -> String {
303    let domain = domain.trim().trim_end_matches('.');
304    if domain.is_empty() {
305        return String::new();
306    }
307    let host = if domain.starts_with('[') && domain.ends_with(']') {
308        &domain[1..domain.len().saturating_sub(1)]
309    } else {
310        domain
311    };
312    if host.contains('*') || host.contains('?') || host.contains('%') {
313        return domain.to_string();
314    }
315    match UrlHost::parse(host) {
316        Ok(host) => host.to_string(),
317        Err(_) => String::new(),
318    }
319}
320
321fn expand_domain_pattern(pattern: &str) -> Vec<String> {
322    match DomainPattern::parse(pattern) {
323        DomainPattern::Exact(domain) => vec![domain],
324        DomainPattern::SubdomainsOnly(domain) => {
325            vec![format!("?*.{domain}")]
326        }
327        DomainPattern::ApexAndSubdomains(domain) => {
328            vec![domain.clone(), format!("?*.{domain}")]
329        }
330    }
331}
332
/// Removes trailing dots from `domain` and lowercases its ASCII letters.
fn normalize_domain(domain: &str) -> String {
    let without_trailing_dots = domain.trim_end_matches('.');
    without_trailing_dots.to_ascii_lowercase()
}
336
/// Compares two domains, ignoring ASCII case and trailing dots.
fn domain_eq(left: &str, right: &str) -> bool {
    let left = left.trim_end_matches('.');
    let right = right.trim_end_matches('.');
    left.eq_ignore_ascii_case(right)
}
340
/// Reports whether `child` equals `parent` or is one of its subdomains.
///
/// Both names are compared without trailing dots and ignoring ASCII case.
fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
    let child = child.trim_end_matches('.').to_ascii_lowercase();
    let parent = parent.trim_end_matches('.').to_ascii_lowercase();
    match child.strip_suffix(parent.as_str()) {
        // Nothing left means the names are equal; otherwise the remainder must
        // end at a label boundary so `notexample.com` is not accepted.
        Some(leading_labels) => leading_labels.is_empty() || leading_labels.ends_with('.'),
        None => false,
    }
}
349
/// Reports whether `child` is a subdomain of `parent`, excluding `parent`
/// itself.
///
/// Both names are compared without trailing dots and ignoring ASCII case.
fn is_strict_subdomain(child: &str, parent: &str) -> bool {
    let child = child.trim_end_matches('.').to_ascii_lowercase();
    let parent = parent.trim_end_matches('.').to_ascii_lowercase();
    // A non-empty remainder ending in `.` means at least one extra label sits
    // in front of `parent`.
    child
        .strip_suffix(parent.as_str())
        .is_some_and(|leading_labels| leading_labels.ends_with('.'))
}
355
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;

    /// Parses `input` into a [`Host`], panicking if it is rejected.
    fn host(input: &str) -> Host {
        Host::parse(input).unwrap()
    }

    /// Compiles a single-pattern denylist, panicking if compilation fails.
    fn denylist(pattern: &str) -> GlobSet {
        compile_denylist_globset(&[pattern.to_string()]).unwrap()
    }

    #[test]
    fn method_allowed_full_allows_everything() {
        for method in ["GET", "POST", "CONNECT"] {
            assert!(NetworkMode::Full.allows_method(method));
        }
    }

    #[test]
    fn method_allowed_limited_allows_only_safe_methods() {
        for method in ["GET", "HEAD", "OPTIONS"] {
            assert!(NetworkMode::Limited.allows_method(method));
        }
        for method in ["POST", "CONNECT"] {
            assert!(!NetworkMode::Limited.allows_method(method));
        }
    }

    #[test]
    fn compile_globset_normalizes_trailing_dots() {
        let set = denylist("Example.COM.");

        assert!(set.is_match("example.com"));
        assert!(!set.is_match("api.example.com"));
    }

    #[test]
    fn compile_globset_normalizes_wildcards() {
        let set = denylist("*.Example.COM.");

        assert!(set.is_match("api.example.com"));
        assert!(!set.is_match("example.com"));
    }

    #[test]
    fn compile_globset_supports_mid_label_wildcards() {
        let set = denylist("region*.v2.argotunnel.com");

        assert!(set.is_match("region1.v2.argotunnel.com"));
        assert!(set.is_match("region.v2.argotunnel.com"));
        assert!(!set.is_match("xregion1.v2.argotunnel.com"));
        assert!(!set.is_match("foo.region1.v2.argotunnel.com"));
    }

    #[test]
    fn compile_globset_normalizes_apex_and_subdomains() {
        let set = denylist("**.Example.COM.");

        assert!(set.is_match("example.com"));
        assert!(set.is_match("api.example.com"));
    }

    #[test]
    fn compile_globset_normalizes_bracketed_ipv6_literals() {
        let set = denylist("[::1]");

        assert!(set.is_match("::1"));
    }

    #[test]
    fn compile_globset_preserves_scoped_ipv6_literals() {
        let set = denylist("[fe80::1%25lo0]");

        assert!(set.is_match("fe80::1%lo0"));
        assert!(!set.is_match("fe80::1%lo1"));
        assert!(!set.is_match("fe80::1"));
    }

    #[test]
    fn is_loopback_host_handles_localhost_variants() {
        for name in ["localhost", "localhost.", "LOCALHOST"] {
            assert!(is_loopback_host(&host(name)));
        }
        assert!(!is_loopback_host(&host("notlocalhost")));
    }

    #[test]
    fn is_loopback_host_handles_ip_literals() {
        assert!(is_loopback_host(&host("127.0.0.1")));
        assert!(is_loopback_host(&host("::1")));
        assert!(!is_loopback_host(&host("1.2.3.4")));
    }

    #[test]
    fn is_non_public_ip_rejects_private_and_loopback_ranges() {
        let non_public = [
            // IPv4 ranges.
            "127.0.0.1",
            "10.0.0.1",
            "192.168.0.1",
            "100.64.0.1",
            "192.0.0.1",
            "192.0.2.1",
            "198.18.0.1",
            "198.51.100.1",
            "203.0.113.1",
            "240.0.0.1",
            "0.1.2.3",
            // IPv4-mapped IPv6.
            "::ffff:127.0.0.1",
            "::ffff:10.0.0.1",
            // Native IPv6.
            "::1",
            "fe80::1",
            "fc00::1",
        ];
        for addr in non_public {
            assert!(is_non_public_ip(addr.parse().unwrap()));
        }

        let public = ["8.8.8.8", "::ffff:8.8.8.8"];
        for addr in public {
            assert!(!is_non_public_ip(addr.parse().unwrap()));
        }
    }

    #[test]
    fn normalize_host_lowercases_and_trims() {
        assert_eq!(normalize_host("  ExAmPlE.CoM  "), "example.com");
    }

    #[test]
    fn normalize_host_strips_port_for_host_port() {
        assert_eq!(normalize_host("example.com:1234"), "example.com");
    }

    #[test]
    fn normalize_host_preserves_unbracketed_ipv6() {
        assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn normalize_host_strips_trailing_dot() {
        for input in ["example.com.", "ExAmPlE.CoM."] {
            assert_eq!(normalize_host(input), "example.com");
        }
    }

    #[test]
    fn normalize_host_strips_trailing_dot_with_port() {
        assert_eq!(normalize_host("example.com.:443"), "example.com");
    }

    #[test]
    fn normalize_host_strips_brackets_for_ipv6() {
        for input in ["[::1]", "[::1]:443"] {
            assert_eq!(normalize_host(input), "::1");
        }
    }

    #[test]
    fn normalize_host_preserves_ipv6_scope_ids() {
        for input in ["fe80::1%lo0", "[fe80::1%lo0]", "[fe80::1%25lo0]"] {
            assert_eq!(normalize_host(input), "fe80::1%lo0");
        }
    }
}