Skip to main content

zendriver_interception/
host_matcher.rs

//! Host-set matcher for the tracker/fingerprinter blocklist.
//!
//! [`HostMatcher`] holds a set of domain names and answers [`is_blocked`]
//! queries with a suffix-on-dot walk: `a.b.evil.com` is blocked if
//! `evil.com` (or any ancestor) is in the set.
//!
//! [`host_of`] extracts the host component from a raw URL string without
//! requiring a URL-parsing dependency — it handles the common `scheme://host`
//! prefix and strips any userinfo, trailing port, path, query, and fragment.
//!
//! [`is_blocked`]: HostMatcher::is_blocked

13use std::collections::HashSet;
14
/// A compiled set of blocked host names with subdomain-walk semantics.
///
/// Constructed once from a domain list, then shared (via [`Arc`]) across
/// rule evaluations per request.
///
/// [`Arc`]: std::sync::Arc
#[derive(Debug, Clone)]
pub struct HostMatcher {
    /// Canonicalized host names: surrounding whitespace and dots removed,
    /// ASCII-lowercased. Never contains the empty string.
    blocked: HashSet<String>,
}

impl HostMatcher {
    /// Build a matcher from an iterator of domain strings.
    ///
    /// Each entry is canonicalized before it is stored:
    ///
    /// - everything from the first `#` onward is a comment and is dropped, so
    ///   both whole-line (`# ads`) and trailing (`evil.com # ads`) comments work;
    /// - surrounding whitespace and dots are trimmed, so `.evil.com` and
    ///   `evil.com.` both become `evil.com` (a leading dot would otherwise make
    ///   the entry unreachable by the label walk in [`is_blocked`]);
    /// - the result is ASCII-lowercased.
    ///
    /// Entries that end up empty (blank lines, comment-only lines) are silently
    /// ignored, so the caller can pass lines from a plain-text host list
    /// directly.
    ///
    /// [`is_blocked`]: HostMatcher::is_blocked
    pub fn new(domains: impl IntoIterator<Item = String>) -> Self {
        let blocked = domains
            .into_iter()
            .map(|line| {
                // A domain name can never contain `#`, so cutting at the first
                // one is safe for every valid entry.
                let without_comment = match line.split_once('#') {
                    Some((kept, _comment)) => kept,
                    None => line.as_str(),
                };
                Self::canonicalize(without_comment)
            })
            .filter(|entry| !entry.is_empty())
            .collect();
        Self { blocked }
    }

    /// Number of distinct blocked host names.
    #[must_use]
    pub fn len(&self) -> usize {
        self.blocked.len()
    }

    /// Returns `true` if no hosts are blocked.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.blocked.is_empty()
    }

    /// Returns `true` if `host` or any of its parent domains is blocked.
    ///
    /// `host` is canonicalized the same way as list entries (whitespace and
    /// surrounding dots trimmed, ASCII-lowercased) and then probed label by
    /// label: for `a.b.evil.com` the candidates are `a.b.evil.com`,
    /// `b.evil.com`, `evil.com`, and `com`. Matching therefore only happens on
    /// a dot boundary (`totallyevil.com` does not match `evil.com`). An empty
    /// host is never blocked.
    #[must_use]
    pub fn is_blocked(&self, host: &str) -> bool {
        let host = Self::canonicalize(host);
        // Each step drops the leftmost label; the sequence ends once a suffix
        // has no dot left, so the walk always terminates.
        std::iter::successors(Some(host.as_str()), |&suffix| {
            suffix.split_once('.').map(|(_, parent)| parent)
        })
        .any(|candidate| self.blocked.contains(candidate))
    }

    /// Canonical form shared by stored entries and queried hosts: leading and
    /// trailing whitespace and dots removed, then ASCII-lowercased.
    fn canonicalize(raw: &str) -> String {
        raw.trim_matches(|c: char| c == '.' || c.is_whitespace())
            .to_ascii_lowercase()
    }
}
77
/// Extract the host component from a raw URL string.
///
/// Handles the hierarchical `scheme://authority/path` form without a
/// URL-parsing dependency:
///
/// - leading ASCII whitespace and control characters are skipped;
/// - the text before the first `:` must be a valid RFC 3986 scheme
///   (`ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`) and must be followed
///   by `//`;
/// - the authority runs up to the first `/`, `\`, `?`, or `#`;
/// - userinfo (`user:pass@`) and a trailing `:port` are removed; bracketed
///   IPv6 literals keep their brackets (`[::1]`).
///
/// The host is returned as written; it is not lowercased or otherwise
/// normalized.
///
/// Returns `None` when there is no valid `scheme://` prefix: e.g.
/// `not-a-url`, or `data:` / `javascript:` / `blob:` URLs, even when their
/// payload contains an embedded `http://…` URL. An empty authority (as in
/// `file:///etc/hosts`) yields `Some("")`.
pub fn host_of(url: &str) -> Option<&str> {
    // URL parsers ignore leading C0 control characters and spaces.
    let url = url.trim_start_matches(|c: char| c <= ' ');

    // The scheme ends at the first `:`. Validating it (rather than searching
    // for the first `://` anywhere) keeps an `http://` embedded in a `data:`
    // or `javascript:` payload from being mistaken for this URL's host.
    let (scheme, rest) = url.split_once(':')?;
    let mut scheme_chars = scheme.chars();
    let scheme_is_valid = matches!(scheme_chars.next(), Some(c) if c.is_ascii_alphabetic())
        && scheme_chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
    if !scheme_is_valid {
        return None;
    }
    let after_scheme = rest.strip_prefix("//")?;

    // Authority ends at the first `/`, `?`, or `#`. `\` is included because
    // browsers treat it as `/` in http(s) URLs: `https://evil.com\@good.com/`
    // is a request to `evil.com`, and must not be read as `good.com`.
    let authority = match after_scheme.find(['/', '\\', '?', '#']) {
        Some(pos) => &after_scheme[..pos],
        None => after_scheme,
    };

    // Userinfo (`user:pass@host`) is delimited by the last `@`.
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host_and_port)| host_and_port);

    // Strip the port, taking care with IPv6 literals like `[::1]:8080`, whose
    // address itself contains colons.
    let host = if host_and_port.starts_with('[') {
        host_and_port
            .find(']')
            .map_or(host_and_port, |end| &host_and_port[..=end])
    } else {
        host_and_port
            .rsplit_once(':')
            .map_or(host_and_port, |(host, _port)| host)
    };
    Some(host)
}
114
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand: build a [`HostMatcher`] from string literals.
    fn matcher_of(domains: &[&str]) -> HostMatcher {
        HostMatcher::new(domains.iter().map(|&domain| domain.to_owned()))
    }

    // --- HostMatcher ----------------------------------------------------------

    #[test]
    fn exact_match_blocked() {
        let m = matcher_of(&["evil.com"]);
        assert!(m.is_blocked("evil.com"));
    }

    #[test]
    fn subdomain_of_listed_domain_is_blocked() {
        let m = matcher_of(&["evil.com"]);
        for host in ["tracker.evil.com", "a.b.tracker.evil.com"] {
            assert!(m.is_blocked(host), "{host}");
        }
    }

    #[test]
    fn unrelated_host_not_blocked() {
        let m = matcher_of(&["evil.com"]);
        // `notevil.com` and `totallyevil.com` end in "evil.com" but not on a
        // dot boundary, so neither may match.
        for host in ["good.com", "notevil.com", "totallyevil.com"] {
            assert!(!m.is_blocked(host), "{host}");
        }
    }

    #[test]
    fn bare_root_in_set_blocks_all_subdomains() {
        // Listing a bare root (or even a TLD) is unusual but valid.
        let m = matcher_of(&["example.com"]);
        assert!(m.is_blocked("example.com"));
        assert!(m.is_blocked("sub.example.com"));
    }

    #[test]
    fn case_insensitive_match() {
        let m = matcher_of(&["Evil.Com"]);
        assert!(m.is_blocked("EVIL.COM"));
        assert!(m.is_blocked("Tracker.Evil.Com"));
    }

    #[test]
    fn empty_matcher_blocks_nothing() {
        assert!(!matcher_of(&[]).is_blocked("evil.com"));
    }

    #[test]
    fn comment_and_blank_lines_ignored() {
        let m = matcher_of(&[
            "# this is a comment",
            "",
            "  ",
            "tracker.example.com",
            "# another comment",
        ]);
        assert!(m.is_blocked("tracker.example.com"));
        // Only the listed entry is blocked, not its parent domain.
        assert!(!m.is_blocked("example.com"));
    }

    #[test]
    fn single_label_host_no_infinite_loop() {
        // A dotless host has no parent to walk to; the lookup must finish
        // after a single probe whether or not it matches.
        let m = matcher_of(&["localhost"]);
        assert!(m.is_blocked("localhost"));
        assert!(!m.is_blocked("otherhost"));
    }

    // --- host_of --------------------------------------------------------------

    #[test]
    fn host_of_simple_url() {
        let host = host_of("https://example.com/path");
        assert_eq!(host, Some("example.com"));
    }

    #[test]
    fn host_of_with_port() {
        let host = host_of("http://example.com:8080/path");
        assert_eq!(host, Some("example.com"));
    }

    #[test]
    fn host_of_no_path() {
        let host = host_of("https://example.com");
        assert_eq!(host, Some("example.com"));
    }

    #[test]
    fn host_of_with_query() {
        let host = host_of("https://example.com?foo=bar");
        assert_eq!(host, Some("example.com"));
    }

    #[test]
    fn host_of_with_fragment() {
        let host = host_of("https://example.com#section");
        assert_eq!(host, Some("example.com"));
    }

    #[test]
    fn host_of_missing_scheme_separator() {
        assert!(host_of("not-a-url").is_none());
    }

    #[test]
    fn host_of_ipv6() {
        // Brackets are kept so the literal stays unambiguous.
        let host = host_of("https://[::1]:443/path");
        assert_eq!(host, Some("[::1]"));
    }

    #[test]
    fn host_of_with_userinfo() {
        let host = host_of("https://user:pass@example.com/path");
        assert_eq!(host, Some("example.com"));
    }

    // --- Integration: host_of + HostMatcher -----------------------------------

    #[test]
    fn is_blocked_using_host_of() {
        let url = "https://cdn.fingerprinter.io/track.js?v=1";
        let host = host_of(url).expect("host_of returned None");
        assert!(matcher_of(&["fingerprinter.io"]).is_blocked(host));
    }
}