Skip to main content

spider_browser/retry/
keyword_classifier.rs

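//! Aho-Corasick keyword classifier -- O(n) multi-pattern substring matching.
//!
//! Scans the input string exactly **once** regardless of how many keywords
//! exist. Returns the classification of the first keyword occurrence found
//! while scanning left to right, i.e. the keyword that *ends* earliest in the
//! input. Rule insertion order only decides which classification is kept when
//! the same keyword appears in more than one rule (the first rule wins).
//!
//! Built at construction time -- zero per-call allocation or compilation.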
8
9use std::collections::HashMap;
10
/// Internal index for nodes in the arena-allocated trie.
///
/// All nodes live in one flat `Vec<TrieNode>` and refer to each other by
/// position in that vector, so fail/dict links are plain `usize` values
/// instead of references that would need lifetimes.
type NodeIdx = usize;

/// Sentinel value meaning "no link" (equivalent to `None` for `Option<NodeIdx>`).
const NIL: NodeIdx = usize::MAX;

/// Arena index of the trie root. The root is created first and never moves.
const ROOT: NodeIdx = 0;

/// A single node in the Aho-Corasick trie.
///
/// `T` is the classification type (e.g. an enum variant).
struct TrieNode<T> {
    /// Transition map: ASCII-lowercased byte -> child node index.
    children: HashMap<u8, NodeIdx>,
    /// Classification emitted when this node completes a keyword.
    /// `None` if this node is not the end of any keyword.
    output: Option<T>,
    /// Failure link -- node for the longest proper suffix of this node's
    /// string that is also a prefix of some keyword. `NIL` until the links
    /// are built (and it stays `NIL` on the root, where it is never followed).
    fail: NodeIdx,
    /// Dictionary suffix link -- nearest node on the fail chain that has an
    /// output, or `NIL` if there is none. Lets the scan detect a keyword that
    /// ends inside a longer partial match with a single extra lookup.
    dict: NodeIdx,
}

impl<T> TrieNode<T> {
    /// Create a node with no children, no output and no links.
    fn new() -> Self {
        Self {
            children: HashMap::new(),
            output: None,
            fail: NIL,
            dict: NIL,
        }
    }
}

/// Aho-Corasick based keyword classifier.
///
/// Built from a list of rules `(keywords, classification)`.
/// [`classify`](Self::classify) walks the input once, left to right, and
/// returns as soon as any keyword has been seen as a substring.
///
/// # Which match is returned
///
/// * The keyword occurrence that **ends earliest** in the input wins,
///   regardless of which rule it came from.
/// * If several keywords end at that same position, the longest one wins.
/// * Rule order matters only for duplicates: when the same keyword appears in
///   more than one rule, the classification of the **first** rule is kept.
///
/// # Matching details
///
/// * Matching is ASCII-case-insensitive: `A-Z` in keywords and input are
///   folded to `a-z`. All other bytes (including non-ASCII UTF-8) must match
///   exactly.
/// * Empty keywords are ignored.
///
/// # Example
///
/// ```
/// use spider_browser::retry::keyword_classifier::KeywordClassifier;
///
/// let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
///     (&["blocked", "captcha", "403"], "blocked"),
///     (&["timeout", "err_connection_reset"], "transient"),
/// ]);
///
/// assert_eq!(classifier.classify("Error 403 Forbidden"), Some(&"blocked"));
/// assert_eq!(classifier.classify("ERR_CONNECTION_RESET"), Some(&"transient"));
/// assert_eq!(classifier.classify("all good"), None);
/// ```
pub struct KeywordClassifier<T> {
    /// Arena-allocated trie nodes. Index 0 (`ROOT`) is always the root.
    nodes: Vec<TrieNode<T>>,
}

impl<T: Clone> KeywordClassifier<T> {
    /// Build a new classifier from a list of rules.
    ///
    /// Each rule is `(keywords, classification)`. Keywords are inserted in
    /// rule order; if the same keyword (compared ASCII-case-insensitively)
    /// occurs in more than one rule, the **first** rule's classification is
    /// kept. Empty keywords are skipped.
    ///
    /// All keywords are stored ASCII-lowercased internally, and the failure
    /// and dictionary links are built once here so that
    /// [`classify`](Self::classify) does no setup work.
    pub fn new(rules: &[(&[&str], T)]) -> Self {
        let mut classifier = Self {
            nodes: vec![TrieNode::new()], // index 0 = root
        };

        for (keywords, cls) in rules {
            for &kw in *keywords {
                classifier.insert(kw, cls.clone());
            }
        }

        classifier.build_failure_links();
        classifier
    }

    /// Classify a string by scanning it once for all keywords.
    ///
    /// Returns the classification attached to the first keyword occurrence
    /// that completes while scanning left to right (preferring the longest
    /// keyword if several end at the same byte), or `None` if no keyword
    /// occurs in `text`.
    ///
    /// ASCII letters are lowercased byte by byte during the scan, so no
    /// lowercased copy of `text` is allocated.
    pub fn classify(&self, text: &str) -> Option<&T> {
        let mut state = ROOT;

        for byte in text.bytes() {
            state = self.next_state(state, byte.to_ascii_lowercase());
            let node = &self.nodes[state];

            // A keyword ends exactly at the current state.
            if let Some(out) = node.output.as_ref() {
                return Some(out);
            }

            // A shorter keyword ends here as a suffix of the current state.
            // `dict` always points at a node with an output, so one hop is
            // enough.
            if node.dict != NIL {
                if let Some(out) = self.nodes[node.dict].output.as_ref() {
                    return Some(out);
                }
            }
        }

        None
    }

    /// Automaton transition: the state reached from `state` on byte `ch`.
    ///
    /// Follows failure links until some state on the chain has a child for
    /// `ch`; falls back to the root when even the root has no such child.
    /// `ch` must already be ASCII-lowercased.
    fn next_state(&self, mut state: NodeIdx, ch: u8) -> NodeIdx {
        loop {
            if let Some(&next) = self.nodes[state].children.get(&ch) {
                return next;
            }
            if state == ROOT {
                return ROOT;
            }
            state = self.nodes[state].fail;
        }
    }

    /// Insert a keyword into the trie with the given classification.
    ///
    /// Empty keywords are ignored: they would turn the root into a terminal
    /// node and make unrelated inputs match unpredictably.
    ///
    /// First-rule-wins: if the terminal node already has an output, the
    /// existing (earlier) classification is kept.
    fn insert(&mut self, word: &str, cls: T) {
        if word.is_empty() {
            return;
        }

        let mut node_idx = ROOT;

        for byte in word.bytes() {
            // Keywords are stored ASCII-lowercased.
            let ch = byte.to_ascii_lowercase();

            node_idx = match self.nodes[node_idx].children.get(&ch).copied() {
                Some(existing) => existing,
                None => {
                    let created = self.nodes.len();
                    self.nodes.push(TrieNode::new());
                    self.nodes[node_idx].children.insert(ch, created);
                    created
                }
            };
        }

        // First rule wins -- do not overwrite an earlier classification.
        let slot = &mut self.nodes[node_idx].output;
        if slot.is_none() {
            *slot = Some(cls);
        }
    }

    /// Build Aho-Corasick failure and dictionary suffix links via BFS.
    ///
    /// Nodes are processed level by level, so when a node is visited every
    /// shallower node (and therefore every node on its fail chain) already
    /// has its final links.
    fn build_failure_links(&mut self) {
        // A `Vec` plus a moving head index serves as the BFS queue.
        let mut queue: Vec<NodeIdx> = self.nodes[ROOT].children.values().copied().collect();

        // Depth-1 nodes: the only proper suffix is the empty string (root),
        // and the root never carries an output.
        for &child_idx in &queue {
            let child = &mut self.nodes[child_idx];
            child.fail = ROOT;
            child.dict = NIL;
        }

        let mut head = 0;
        while head < queue.len() {
            let node_idx = queue[head];
            head += 1;

            // Copy the edges out so the arena can be mutated while iterating.
            let edges: Vec<(u8, NodeIdx)> = self.nodes[node_idx]
                .children
                .iter()
                .map(|(&ch, &idx)| (ch, idx))
                .collect();
            let parent_fail = self.nodes[node_idx].fail;

            for (ch, child_idx) in edges {
                // The fail target is strictly shallower than `child_idx`, so
                // it can never be the child itself.
                let child_fail = self.next_state(parent_fail, ch);

                // Dictionary suffix link: nearest node on the fail chain
                // that has an output.
                let fail_node = &self.nodes[child_fail];
                let child_dict = if fail_node.output.is_some() {
                    child_fail
                } else {
                    fail_node.dict
                };

                let child = &mut self.nodes[child_idx];
                child.fail = child_fail;
                child.dict = child_dict;

                queue.push(child_idx);
            }
        }
    }
}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Keywords from different rules map to their own rule's label.
    #[test]
    fn basic_classification() {
        let kc: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["blocked", "403", "captcha"], "blocked"),
            (&["timeout"], "transient"),
        ]);

        let blocked = kc.classify("Error 403 Forbidden");
        let transient = kc.classify("Request timed out: timeout");
        let clean = kc.classify("success");

        assert_eq!(blocked, Some(&"blocked"));
        assert_eq!(transient, Some(&"transient"));
        assert_eq!(clean, None);
    }

    /// Upper, mixed and lower case inputs all hit a lowercase keyword.
    #[test]
    fn case_insensitive() {
        let kc: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["captcha"], "blocked"),
        ]);

        for input in ["CAPTCHA detected", "CaPtChA", "captcha"].iter() {
            assert_eq!(kc.classify(input), Some(&"blocked"));
        }
    }

    /// A keyword listed in two rules keeps the label of the earlier rule.
    #[test]
    fn first_rule_wins() {
        let kc: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["timeout"], "blocked"),
            (&["timeout"], "transient"),
        ]);

        let label = kc.classify("timeout error");
        assert_eq!(label, Some(&"blocked"));
    }

    /// Keywords sharing a long common prefix are still told apart.
    #[test]
    fn overlapping_patterns() {
        let kc: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["bot detect", "bot protection"], "blocked"),
            (&["err_connection_reset", "err_connection_closed"], "transient"),
        ]);

        let bot = kc.classify("Detected bot detection script");
        assert_eq!(bot, Some(&"blocked"));

        let reset = kc.classify("net::ERR_CONNECTION_RESET");
        assert_eq!(reset, Some(&"transient"));
    }

    /// Inputs without any keyword, including the empty string, yield `None`.
    #[test]
    fn no_match_returns_none() {
        let kc: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["foo"], "a"),
            (&["bar"], "b"),
        ]);

        for input in ["baz qux", ""].iter() {
            assert_eq!(kc.classify(input), None);
        }
    }

    /// A keyword is found when it is embedded in the middle of the input.
    #[test]
    fn substring_matching() {
        let kc: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["403"], "blocked"),
        ]);

        let label = kc.classify("HTTP/1.1 403 Forbidden");
        assert_eq!(label, Some(&"blocked"));
    }

    /// Every keyword of a rule produces that rule's classification.
    #[test]
    fn multiple_keywords_same_rule() {
        let kc: KeywordClassifier<i32> = KeywordClassifier::new(&[
            (&["alpha", "beta", "gamma"], 1),
            (&["delta", "epsilon"], 2),
        ]);

        let cases: [(&str, Option<i32>); 3] = [
            ("testing beta value", Some(1)),
            ("epsilon result", Some(2)),
            ("zeta", None),
        ];
        for (input, want) in cases.iter() {
            assert_eq!(kc.classify(input).copied(), *want);
        }
    }

    /// Exercises failure and dictionary links with one keyword nested inside
    /// another.
    #[test]
    fn aho_corasick_shared_prefix() {
        let nested: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["abcde"], "first"),
            (&["bcd"], "second"),
        ]);

        // Only "bcd" occurs in this input.
        assert_eq!(nested.classify("xxbcdxx"), Some(&"second"));
        // Inside "abcde" the inner keyword "bcd" is complete one byte before
        // "abcde" itself, so the scan reports "second" first.
        assert_eq!(nested.classify("abcde"), Some(&"second"));

        // Two unrelated keywords: each input contains exactly one of them.
        let disjoint: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["xyz"], "first"),
            (&["abc"], "second"),
        ]);
        assert_eq!(disjoint.classify("xxxyzxx"), Some(&"first"));
        assert_eq!(disjoint.classify("xxabcxx"), Some(&"second"));
    }

    /// End-to-end check with an enum classification and realistic messages.
    #[test]
    fn real_world_error_messages() {
        #[derive(Clone, Debug, PartialEq)]
        enum ErrorClass {
            Blocked,
            Auth,
            BackendDown,
            Transient,
        }

        let kc: KeywordClassifier<ErrorClass> = KeywordClassifier::new(&[
            (
                &[
                    "bot detect", "blocked", "403", "captcha",
                    "checking your browser", "access denied",
                ],
                ErrorClass::Blocked,
            ),
            (&["401", "unauthorized"], ErrorClass::Auth),
            (
                &["backend unavailable", "503", "service unavailable"],
                ErrorClass::BackendDown,
            ),
            (
                &["err_connection_reset", "timeout", "websocket closed"],
                ErrorClass::Transient,
            ),
        ]);

        let expectations: [(&str, Option<ErrorClass>); 5] = [
            ("Error: 403 Forbidden - Access Denied", Some(ErrorClass::Blocked)),
            ("HTTP 401 Unauthorized", Some(ErrorClass::Auth)),
            ("503 Service Temporarily Unavailable", Some(ErrorClass::BackendDown)),
            ("net::ERR_CONNECTION_RESET at navigation", Some(ErrorClass::Transient)),
            ("Page loaded successfully", None),
        ];

        for (message, want) in expectations.iter() {
            assert_eq!(kc.classify(message), want.as_ref());
        }
    }
}