Skip to main content

zerobox_network_proxy/
policy.rs

1#[cfg(test)]
2use crate::config::NetworkMode;
3use anyhow::Context;
4use anyhow::Result;
5use anyhow::bail;
6use anyhow::ensure;
7use globset::GlobBuilder;
8use globset::GlobSet;
9use globset::GlobSetBuilder;
10use std::collections::HashSet;
11use std::net::IpAddr;
12use std::net::Ipv4Addr;
13use std::net::Ipv6Addr;
14use url::Host as UrlHost;
15
16/// A normalized host string for policy evaluation.
17#[derive(Clone, Debug, PartialEq, Eq, Hash)]
18pub struct Host(String);
19
20impl Host {
21    pub fn parse(input: &str) -> Result<Self> {
22        let normalized = normalize_host(input);
23        ensure!(!normalized.is_empty(), "host is empty");
24        Ok(Self(normalized))
25    }
26
27    pub fn as_str(&self) -> &str {
28        &self.0
29    }
30}
31
32/// Returns true if the host is a loopback hostname or IP literal.
33pub fn is_loopback_host(host: &Host) -> bool {
34    let host = host.as_str();
35    let host = host.split_once('%').map(|(ip, _)| ip).unwrap_or(host);
36    if host == "localhost" {
37        return true;
38    }
39    if let Ok(ip) = host.parse::<IpAddr>() {
40        return ip.is_loopback();
41    }
42    false
43}
44
/// Returns true if `ip` is not a publicly routable address.
///
/// This is the SSRF guard: loopback, private, link-local, documentation, benchmarking,
/// reserved, multicast and similar special-purpose ranges are all treated as non-public.
pub fn is_non_public_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ip) => is_non_public_ipv4(ip),
        IpAddr::V6(ip) => is_non_public_ipv6(ip),
    }
}

fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
    // Use the standard library classification helpers where possible; they encode the intent more
    // clearly than hand-rolled range checks. Some non-public ranges (e.g., CGNAT and TEST-NET
    // blocks) are not covered by stable stdlib helpers yet, so we fall back to CIDR checks.
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_broadcast()
        || ipv4_in_cidr(ip, [0, 0, 0, 0], /*prefix*/ 8) // "this network" (RFC 1122)
        || ipv4_in_cidr(ip, [100, 64, 0, 0], /*prefix*/ 10) // CGNAT (RFC 6598)
        || ipv4_in_cidr(ip, [192, 0, 0, 0], /*prefix*/ 24) // IETF Protocol Assignments (RFC 6890)
        || ipv4_in_cidr(ip, [192, 0, 2, 0], /*prefix*/ 24) // TEST-NET-1 (RFC 5737)
        || ipv4_in_cidr(ip, [198, 18, 0, 0], /*prefix*/ 15) // Benchmarking (RFC 2544)
        || ipv4_in_cidr(ip, [198, 51, 100, 0], /*prefix*/ 24) // TEST-NET-2 (RFC 5737)
        || ipv4_in_cidr(ip, [203, 0, 113, 0], /*prefix*/ 24) // TEST-NET-3 (RFC 5737)
        || ipv4_in_cidr(ip, [240, 0, 0, 0], /*prefix*/ 4) // Reserved (RFC 6890)
}

/// Returns true if `ip` lies in `base/prefix` (`prefix` expected in `0..=32`).
fn ipv4_in_cidr(ip: Ipv4Addr, base: [u8; 4], prefix: u8) -> bool {
    let ip = u32::from(ip);
    let base = u32::from(Ipv4Addr::from(base));
    let mask = if prefix == 0 {
        0
    } else {
        u32::MAX << (32 - prefix)
    };
    (ip & mask) == (base & mask)
}

fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
    // IPv4-mapped (`::ffff:a.b.c.d`) and IPv4-compatible (`::a.b.c.d`) addresses are judged by
    // the IPv4 address they carry, so e.g. `::ffff:10.0.0.1` cannot bypass the IPv4 checks.
    if let Some(v4) = ip.to_ipv4() {
        return is_non_public_ipv4(v4) || ip.is_loopback();
    }
    // NAT64 well-known prefix `64:ff9b::/96` (RFC 6052): a translator forwards these to the
    // embedded IPv4 address, so classify by that address (e.g. `64:ff9b::a00:1` -> 10.0.0.1).
    if ipv6_in_cidr(ip, [0x0064, 0xff9b, 0, 0, 0, 0, 0, 0], /*prefix*/ 96) {
        let [.., a, b, c, d] = ip.octets();
        return is_non_public_ipv4(Ipv4Addr::new(a, b, c, d));
    }
    // Treat anything that isn't globally routable as "local" for SSRF prevention. In particular:
    //  - `::1` loopback
    //  - `::` unspecified
    //  - multicast ranges
    //  - `fc00::/7` unique-local (RFC 4193)
    //  - `fe80::/10` link-local
    ip.is_loopback()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_unique_local()
        || ip.is_unicast_link_local()
        || ipv6_in_cidr(ip, [0xfec0, 0, 0, 0, 0, 0, 0, 0], /*prefix*/ 10) // Site-local (RFC 3879)
        || ipv6_in_cidr(ip, [0x0064, 0xff9b, 1, 0, 0, 0, 0, 0], /*prefix*/ 48) // Local-use NAT64 (RFC 8215)
        || ipv6_in_cidr(ip, [0x0100, 0, 0, 0, 0, 0, 0, 0], /*prefix*/ 64) // Discard-only (RFC 6666)
        || ipv6_in_cidr(ip, [0x2001, 0x0db8, 0, 0, 0, 0, 0, 0], /*prefix*/ 32) // Documentation (RFC 3849)
}

/// Returns true if `ip` lies in `base/prefix` (`prefix` expected in `0..=128`).
fn ipv6_in_cidr(ip: Ipv6Addr, base: [u16; 8], prefix: u8) -> bool {
    let ip = u128::from(ip);
    let base = u128::from(Ipv6Addr::from(base));
    let mask = if prefix == 0 {
        0
    } else {
        u128::MAX << (128 - prefix)
    };
    (ip & mask) == (base & mask)
}
99
/// Normalize host fragments for policy matching.
///
/// Trims surrounding whitespace, unwraps a bracketed IPv6 literal (anything after the first `]`,
/// such as a port, is dropped), strips a `:port` suffix when the input contains exactly one `:`,
/// lowercases ASCII letters, and removes trailing dots so `example.com.` and `example.com`
/// compare equal. Unbracketed IPv6 literals (two or more colons) keep all their colons.
pub fn normalize_host(host: &str) -> String {
    let trimmed = host.trim();

    let bracketed = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .map(|(inner, _)| inner);
    if let Some(inner) = bracketed {
        return normalize_dns_host(inner);
    }

    let without_port = match trimmed.matches(':').count() {
        // The proxy stack should typically hand us a bare host, but defensively drop `:port`.
        1 => trimmed.split_once(':').map_or(trimmed, |(name, _)| name),
        // Zero colons: nothing to strip. Two or more: an unbracketed IPv6 literal, left alone.
        _ => trimmed,
    };
    normalize_dns_host(without_port)
}

/// Lowercases ASCII characters and strips any trailing dots.
fn normalize_dns_host(host: &str) -> String {
    let lowered = host.to_ascii_lowercase();
    lowered.trim_end_matches('.').to_owned()
}
125
126fn normalize_pattern(pattern: &str) -> String {
127    let pattern = pattern.trim();
128    if pattern == "*" {
129        return "*".to_string();
130    }
131
132    let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
133        ("**.", domain)
134    } else if let Some(domain) = pattern.strip_prefix("*.") {
135        ("*.", domain)
136    } else {
137        ("", pattern)
138    };
139
140    let remainder = normalize_host(remainder);
141    if prefix.is_empty() {
142        remainder
143    } else {
144        format!("{prefix}{remainder}")
145    }
146}
147
148pub(crate) fn is_global_wildcard_domain_pattern(pattern: &str) -> bool {
149    let normalized = normalize_pattern(pattern);
150    expand_domain_pattern(&normalized)
151        .iter()
152        .any(|candidate| candidate == "*")
153}
154
/// Whether a pattern expanding to the bare `*` glob is accepted during glob-set compilation.
#[derive(Clone, Copy, PartialEq, Eq)]
enum GlobalWildcard {
    /// Accept `*` and compile it into a match-everything glob.
    Allow,
    /// Fail compilation with an error when a pattern expands to `*`.
    Reject,
}
160
161pub(crate) fn compile_allowlist_globset(patterns: &[String]) -> Result<GlobSet> {
162    compile_globset_with_policy(patterns, GlobalWildcard::Allow)
163}
164
165pub(crate) fn compile_denylist_globset(patterns: &[String]) -> Result<GlobSet> {
166    compile_globset_with_policy(patterns, GlobalWildcard::Reject)
167}
168
169fn compile_globset_with_policy(
170    patterns: &[String],
171    global_wildcard: GlobalWildcard,
172) -> Result<GlobSet> {
173    let mut builder = GlobSetBuilder::new();
174    let mut seen = HashSet::new();
175    for pattern in patterns {
176        if global_wildcard == GlobalWildcard::Reject && is_global_wildcard_domain_pattern(pattern) {
177            bail!(
178                "unsupported global wildcard domain pattern \"*\"; use exact hosts or scoped wildcards like *.example.com or **.example.com"
179            );
180        }
181        let pattern = normalize_pattern(pattern);
182        // Supported domain patterns:
183        // - "example.com": match the exact host
184        // - "*.example.com": match any subdomain (not the apex)
185        // - "**.example.com": match the apex and any subdomain
186        // - "*": match every host when explicitly enabled for allowlist compilation
187        for candidate in expand_domain_pattern(&pattern) {
188            if !seen.insert(candidate.clone()) {
189                continue;
190            }
191            let glob = GlobBuilder::new(&candidate)
192                .case_insensitive(true)
193                .build()
194                .with_context(|| format!("invalid glob pattern: {candidate}"))?;
195            builder.add(glob);
196        }
197    }
198    Ok(builder.build()?)
199}
200
/// A domain pattern decoded into its wildcard scope.
#[derive(Debug, Clone)]
pub(crate) enum DomainPattern {
    /// `**.domain`: the domain itself plus every subdomain.
    ApexAndSubdomains(String),
    /// `*.domain`: every subdomain, but not the domain itself.
    SubdomainsOnly(String),
    /// A single host (which may contain mid-label glob characters), or `""` for empty input.
    Exact(String),
}
207
208impl DomainPattern {
209    /// Parse a policy pattern for constraint comparisons.
210    ///
211    /// Validation of glob syntax happens when building the globset; here we only
212    /// decode the wildcard prefixes to keep constraint checks lightweight.
213    pub(crate) fn parse(input: &str) -> Self {
214        let input = input.trim();
215        if input.is_empty() {
216            return Self::Exact(String::new());
217        }
218        if let Some(domain) = input.strip_prefix("**.") {
219            Self::parse_domain(domain, Self::ApexAndSubdomains)
220        } else if let Some(domain) = input.strip_prefix("*.") {
221            Self::parse_domain(domain, Self::SubdomainsOnly)
222        } else {
223            Self::Exact(input.to_string())
224        }
225    }
226
227    /// Parse a policy pattern for constraint comparisons, validating domain parts with `url`.
228    pub(crate) fn parse_for_constraints(input: &str) -> Self {
229        let input = input.trim();
230        if input.is_empty() {
231            return Self::Exact(String::new());
232        }
233        if let Some(domain) = input.strip_prefix("**.") {
234            return Self::ApexAndSubdomains(parse_domain_for_constraints(domain));
235        }
236        if let Some(domain) = input.strip_prefix("*.") {
237            return Self::SubdomainsOnly(parse_domain_for_constraints(domain));
238        }
239        Self::Exact(parse_domain_for_constraints(input))
240    }
241
242    fn parse_domain(domain: &str, build: impl FnOnce(String) -> Self) -> Self {
243        let domain = domain.trim();
244        if domain.is_empty() {
245            return Self::Exact(String::new());
246        }
247        build(domain.to_string())
248    }
249
250    pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
251        match self {
252            DomainPattern::Exact(domain) => match candidate {
253                DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
254                _ => false,
255            },
256            DomainPattern::SubdomainsOnly(domain) => match candidate {
257                DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
258                DomainPattern::SubdomainsOnly(candidate) => {
259                    is_subdomain_or_equal(candidate, domain)
260                }
261                DomainPattern::ApexAndSubdomains(candidate) => {
262                    is_strict_subdomain(candidate, domain)
263                }
264            },
265            DomainPattern::ApexAndSubdomains(domain) => match candidate {
266                DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
267                DomainPattern::SubdomainsOnly(candidate) => {
268                    is_subdomain_or_equal(candidate, domain)
269                }
270                DomainPattern::ApexAndSubdomains(candidate) => {
271                    is_subdomain_or_equal(candidate, domain)
272                }
273            },
274        }
275    }
276}
277
278fn parse_domain_for_constraints(domain: &str) -> String {
279    let domain = domain.trim().trim_end_matches('.');
280    if domain.is_empty() {
281        return String::new();
282    }
283    let host = if domain.starts_with('[') && domain.ends_with(']') {
284        &domain[1..domain.len().saturating_sub(1)]
285    } else {
286        domain
287    };
288    if host.contains('*') || host.contains('?') || host.contains('%') {
289        return domain.to_string();
290    }
291    match UrlHost::parse(host) {
292        Ok(host) => host.to_string(),
293        Err(_) => String::new(),
294    }
295}
296
297fn expand_domain_pattern(pattern: &str) -> Vec<String> {
298    match DomainPattern::parse(pattern) {
299        DomainPattern::Exact(domain) => vec![domain],
300        DomainPattern::SubdomainsOnly(domain) => {
301            vec![format!("?*.{domain}")]
302        }
303        DomainPattern::ApexAndSubdomains(domain) => {
304            vec![domain.clone(), format!("?*.{domain}")]
305        }
306    }
307}
308
/// Canonical form used for domain comparisons: trailing dots removed, ASCII lowercased.
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim_end_matches('.');
    trimmed.to_ascii_lowercase()
}

/// Case-insensitive domain equality that ignores trailing dots.
fn domain_eq(left: &str, right: &str) -> bool {
    let (left, right) = (normalize_domain(left), normalize_domain(right));
    left == right
}

/// True if `child` equals `parent` or lies beneath it on a label boundary.
fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
    let child = normalize_domain(child);
    let parent = normalize_domain(parent);
    match child.strip_suffix(parent.as_str()) {
        // Empty head: same domain. Head ending in '.': a whole-label subdomain.
        Some(head) => head.is_empty() || head.ends_with('.'),
        None => false,
    }
}

/// True if `child` lies strictly beneath `parent` (at least one extra label).
fn is_strict_subdomain(child: &str, parent: &str) -> bool {
    let child = normalize_domain(child);
    let parent = normalize_domain(parent);
    child
        .strip_suffix(parent.as_str())
        .is_some_and(|head| head.ends_with('.'))
}
331
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;

    #[test]
    fn method_allowed_full_allows_everything() {
        assert!(NetworkMode::Full.allows_method("GET"));
        assert!(NetworkMode::Full.allows_method("POST"));
        assert!(NetworkMode::Full.allows_method("CONNECT"));
    }

    #[test]
    fn method_allowed_limited_allows_only_safe_methods() {
        assert!(NetworkMode::Limited.allows_method("GET"));
        assert!(NetworkMode::Limited.allows_method("HEAD"));
        assert!(NetworkMode::Limited.allows_method("OPTIONS"));
        assert!(!NetworkMode::Limited.allows_method("POST"));
        assert!(!NetworkMode::Limited.allows_method("CONNECT"));
    }

    #[test]
    fn compile_globset_normalizes_trailing_dots() {
        let set = compile_denylist_globset(&["Example.COM.".to_string()]).unwrap();

        assert_eq!(true, set.is_match("example.com"));
        assert_eq!(false, set.is_match("api.example.com"));
    }

    #[test]
    fn compile_globset_normalizes_wildcards() {
        let set = compile_denylist_globset(&["*.Example.COM.".to_string()]).unwrap();

        assert_eq!(true, set.is_match("api.example.com"));
        assert_eq!(false, set.is_match("example.com"));
    }

    #[test]
    fn compile_globset_supports_mid_label_wildcards() {
        let set = compile_denylist_globset(&["region*.v2.argotunnel.com".to_string()]).unwrap();

        assert_eq!(true, set.is_match("region1.v2.argotunnel.com"));
        assert_eq!(true, set.is_match("region.v2.argotunnel.com"));
        assert_eq!(false, set.is_match("xregion1.v2.argotunnel.com"));
        assert_eq!(false, set.is_match("foo.region1.v2.argotunnel.com"));
    }

    #[test]
    fn compile_globset_normalizes_apex_and_subdomains() {
        let set = compile_denylist_globset(&["**.Example.COM.".to_string()]).unwrap();

        assert_eq!(true, set.is_match("example.com"));
        assert_eq!(true, set.is_match("api.example.com"));
    }

    #[test]
    fn compile_globset_normalizes_bracketed_ipv6_literals() {
        let set = compile_denylist_globset(&["[::1]".to_string()]).unwrap();

        assert_eq!(true, set.is_match("::1"));
    }

    #[test]
    fn compile_globset_deduplicates_expanded_candidates() {
        // `**.example.com` already expands to `example.com` and `?*.example.com`.
        let patterns = ["**.example.com", "example.com", "*.Example.com"].map(String::from);
        let set = compile_denylist_globset(&patterns).unwrap();

        assert_eq!(2, set.len());
    }

    #[test]
    fn compile_denylist_rejects_global_wildcard() {
        assert!(compile_denylist_globset(&["*".to_string()]).is_err());
        assert!(compile_denylist_globset(&["**.*".to_string()]).is_err());
    }

    #[test]
    fn compile_allowlist_accepts_global_wildcard() {
        let set = compile_allowlist_globset(&["*".to_string()]).unwrap();

        assert_eq!(true, set.is_match("example.com"));
        assert_eq!(true, set.is_match("10.0.0.1"));
    }

    #[test]
    fn is_loopback_host_handles_localhost_variants() {
        assert!(is_loopback_host(&Host::parse("localhost").unwrap()));
        assert!(is_loopback_host(&Host::parse("localhost.").unwrap()));
        assert!(is_loopback_host(&Host::parse("LOCALHOST").unwrap()));
        assert!(!is_loopback_host(&Host::parse("notlocalhost").unwrap()));
    }

    #[test]
    fn is_loopback_host_handles_ip_literals() {
        assert!(is_loopback_host(&Host::parse("127.0.0.1").unwrap()));
        assert!(is_loopback_host(&Host::parse("::1").unwrap()));
        assert!(!is_loopback_host(&Host::parse("1.2.3.4").unwrap()));
    }

    #[test]
    fn is_loopback_host_ignores_ipv6_zone_id() {
        assert!(is_loopback_host(&Host::parse("::1%lo0").unwrap()));
    }

    #[test]
    fn is_non_public_ip_rejects_private_and_loopback_ranges() {
        assert!(is_non_public_ip("127.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("10.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("192.168.0.1".parse().unwrap()));
        assert!(is_non_public_ip("100.64.0.1".parse().unwrap()));
        assert!(is_non_public_ip("192.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("192.0.2.1".parse().unwrap()));
        assert!(is_non_public_ip("198.18.0.1".parse().unwrap()));
        assert!(is_non_public_ip("198.51.100.1".parse().unwrap()));
        assert!(is_non_public_ip("203.0.113.1".parse().unwrap()));
        assert!(is_non_public_ip("240.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("0.1.2.3".parse().unwrap()));
        assert!(!is_non_public_ip("8.8.8.8".parse().unwrap()));

        assert!(is_non_public_ip("::ffff:127.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("::ffff:10.0.0.1".parse().unwrap()));
        assert!(!is_non_public_ip("::ffff:8.8.8.8".parse().unwrap()));

        assert!(is_non_public_ip("::1".parse().unwrap()));
        assert!(is_non_public_ip("fe80::1".parse().unwrap()));
        assert!(is_non_public_ip("fc00::1".parse().unwrap()));
    }

    #[test]
    fn is_non_public_ip_rejects_metadata_multicast_and_unspecified() {
        // Cloud metadata endpoints live in IPv4 link-local space.
        assert!(is_non_public_ip("169.254.169.254".parse().unwrap()));
        assert!(is_non_public_ip("224.0.0.1".parse().unwrap()));
        assert!(is_non_public_ip("255.255.255.255".parse().unwrap()));
        assert!(is_non_public_ip("0.0.0.0".parse().unwrap()));
        assert!(is_non_public_ip("::".parse().unwrap()));
        assert!(is_non_public_ip("ff02::1".parse().unwrap()));
    }

    #[test]
    fn domain_pattern_allows_respects_scope() {
        let apex = DomainPattern::parse("**.example.com");
        let subs = DomainPattern::parse("*.example.com");
        let exact = DomainPattern::parse("example.com");
        let child = DomainPattern::parse("api.example.com");

        assert!(apex.allows(&exact));
        assert!(apex.allows(&subs));
        assert!(apex.allows(&child));
        assert!(!apex.allows(&DomainPattern::parse("notexample.com")));

        assert!(subs.allows(&child));
        assert!(!subs.allows(&exact));
        assert!(!subs.allows(&apex));

        assert!(exact.allows(&DomainPattern::parse("EXAMPLE.com.")));
        assert!(!exact.allows(&child));
        assert!(!exact.allows(&subs));
    }

    #[test]
    fn normalize_host_lowercases_and_trims() {
        assert_eq!(normalize_host("  ExAmPlE.CoM  "), "example.com");
    }

    #[test]
    fn normalize_host_strips_port_for_host_port() {
        assert_eq!(normalize_host("example.com:1234"), "example.com");
    }

    #[test]
    fn normalize_host_preserves_unbracketed_ipv6() {
        assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn normalize_host_strips_trailing_dot() {
        assert_eq!(normalize_host("example.com."), "example.com");
        assert_eq!(normalize_host("ExAmPlE.CoM."), "example.com");
    }

    #[test]
    fn normalize_host_strips_trailing_dot_with_port() {
        assert_eq!(normalize_host("example.com.:443"), "example.com");
    }

    #[test]
    fn normalize_host_strips_brackets_for_ipv6() {
        assert_eq!(normalize_host("[::1]"), "::1");
        assert_eq!(normalize_host("[::1]:443"), "::1");
    }
}