1use std::collections::HashMap;
10
/// Position of a node inside the classifier's node vector.
type NodeIdx = usize;

/// Sentinel index meaning "no node / no link".
const NIL: NodeIdx = NodeIdx::MAX;
20
21struct TrieNode<T> {
25 children: HashMap<u8, NodeIdx>,
27 output: Option<T>,
30 fail: NodeIdx,
32 dict: NodeIdx,
35}
36
37impl<T> TrieNode<T> {
38 fn new() -> Self {
39 Self {
40 children: HashMap::new(),
41 output: None,
42 fail: NIL,
43 dict: NIL,
44 }
45 }
46}
47
48pub struct KeywordClassifier<T> {
72 nodes: Vec<TrieNode<T>>,
74}
75
76impl<T: Clone> KeywordClassifier<T> {
77 pub fn new(rules: &[(&[&str], T)]) -> Self {
85 let mut classifier = Self {
86 nodes: vec![TrieNode::new()], };
88
89 for (keywords, cls) in rules {
90 for &kw in *keywords {
91 classifier.insert(kw, cls.clone());
92 }
93 }
94
95 classifier.build_failure_links();
96 classifier
97 }
98
99 pub fn classify(&self, text: &str) -> Option<&T> {
107 let mut node_idx: NodeIdx = 0; for byte in text.as_bytes() {
110 let ch = if byte.is_ascii_uppercase() {
112 byte | 0x20
113 } else {
114 *byte
115 };
116
117 while node_idx != 0 && !self.nodes[node_idx].children.contains_key(&ch) {
119 node_idx = self.nodes[node_idx].fail;
120 }
121
122 node_idx = self.nodes[node_idx]
123 .children
124 .get(&ch)
125 .copied()
126 .unwrap_or(0);
127
128 if let Some(ref out) = self.nodes[node_idx].output {
130 return Some(out);
131 }
132
133 let dict_idx = self.nodes[node_idx].dict;
135 if dict_idx != NIL {
136 if let Some(ref out) = self.nodes[dict_idx].output {
137 return Some(out);
138 }
139 }
140 }
141
142 None
143 }
144
145 fn insert(&mut self, word: &str, cls: T) {
150 let mut node_idx: NodeIdx = 0; for byte in word.as_bytes() {
153 let ch = if byte.is_ascii_uppercase() {
155 byte | 0x20
156 } else {
157 *byte
158 };
159
160 if let Some(&child_idx) = self.nodes[node_idx].children.get(&ch) {
161 node_idx = child_idx;
162 } else {
163 let child_idx = self.nodes.len();
164 self.nodes.push(TrieNode::new());
165 self.nodes[node_idx].children.insert(ch, child_idx);
166 node_idx = child_idx;
167 }
168 }
169
170 if self.nodes[node_idx].output.is_none() {
172 self.nodes[node_idx].output = Some(cls);
173 }
174 }
175
176 fn build_failure_links(&mut self) {
178 let mut queue: Vec<NodeIdx> = Vec::new();
181
182 let root_children: Vec<(u8, NodeIdx)> = self.nodes[0]
184 .children
185 .iter()
186 .map(|(&ch, &idx)| (ch, idx))
187 .collect();
188
189 for (_ch, child_idx) in &root_children {
190 self.nodes[*child_idx].fail = 0;
191 self.nodes[*child_idx].dict = NIL;
192 queue.push(*child_idx);
193 }
194
195 let mut head: usize = 0;
196
197 while head < queue.len() {
198 let node_idx = queue[head];
199 head += 1;
200
201 let children: Vec<(u8, NodeIdx)> = self.nodes[node_idx]
203 .children
204 .iter()
205 .map(|(&ch, &idx)| (ch, idx))
206 .collect();
207
208 for (ch, child_idx) in children {
209 let mut fail = self.nodes[node_idx].fail;
211 while fail != 0 && !self.nodes[fail].children.contains_key(&ch) {
212 fail = self.nodes[fail].fail;
213 }
214
215 let child_fail = self.nodes[fail]
216 .children
217 .get(&ch)
218 .copied()
219 .unwrap_or(0);
220
221 let child_fail = if child_fail == child_idx { 0 } else { child_fail };
223
224 self.nodes[child_idx].fail = child_fail;
225
226 self.nodes[child_idx].dict = if self.nodes[child_fail].output.is_some() {
228 child_fail
229 } else {
230 self.nodes[child_fail].dict
231 };
232
233 queue.push(child_idx);
234 }
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_classification() {
        let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["blocked", "403", "captcha"], "blocked"),
            (&["timeout"], "transient"),
        ]);

        let expectations = [
            ("Error 403 Forbidden", Some("blocked")),
            ("Request timed out: timeout", Some("transient")),
            ("success", None),
        ];
        for (text, expected) in expectations {
            assert_eq!(classifier.classify(text).copied(), expected);
        }
    }

    #[test]
    fn case_insensitive() {
        let classifier: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["captcha"], "blocked")]);

        for text in ["CAPTCHA detected", "CaPtChA", "captcha"] {
            assert_eq!(classifier.classify(text).copied(), Some("blocked"));
        }
    }

    #[test]
    fn first_rule_wins() {
        // The same keyword is listed by two rules; the earlier rule keeps it.
        let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["timeout"], "blocked"),
            (&["timeout"], "transient"),
        ]);

        assert_eq!(classifier.classify("timeout error").copied(), Some("blocked"));
    }

    #[test]
    fn overlapping_patterns() {
        let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["bot detect", "bot protection"], "blocked"),
            (&["err_connection_reset", "err_connection_closed"], "transient"),
        ]);

        let expectations = [
            ("Detected bot detection script", "blocked"),
            ("net::ERR_CONNECTION_RESET", "transient"),
        ];
        for (text, expected) in expectations {
            assert_eq!(classifier.classify(text).copied(), Some(expected));
        }
    }

    #[test]
    fn no_match_returns_none() {
        let classifier: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["foo"], "a"), (&["bar"], "b")]);

        for text in ["baz qux", ""] {
            assert!(classifier.classify(text).is_none());
        }
    }

    #[test]
    fn substring_matching() {
        let classifier: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["403"], "blocked")]);

        assert_eq!(
            classifier.classify("HTTP/1.1 403 Forbidden").copied(),
            Some("blocked")
        );
    }

    #[test]
    fn multiple_keywords_same_rule() {
        let classifier: KeywordClassifier<i32> = KeywordClassifier::new(&[
            (&["alpha", "beta", "gamma"], 1),
            (&["delta", "epsilon"], 2),
        ]);

        let expectations = [
            ("testing beta value", Some(1)),
            ("epsilon result", Some(2)),
            ("zeta", None),
        ];
        for (text, expected) in expectations {
            assert_eq!(classifier.classify(text).copied(), expected);
        }
    }

    #[test]
    fn aho_corasick_shared_prefix() {
        // "bcd" ends before "abcde" does, so it is reported first even though
        // its rule is listed second.
        let nested: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["abcde"], "first"), (&["bcd"], "second")]);
        assert_eq!(nested.classify("xxbcdxx").copied(), Some("second"));
        assert_eq!(nested.classify("abcde").copied(), Some("second"));

        let disjoint: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["xyz"], "first"), (&["abc"], "second")]);
        assert_eq!(disjoint.classify("xxxyzxx").copied(), Some("first"));
        assert_eq!(disjoint.classify("xxabcxx").copied(), Some("second"));
    }

    #[test]
    fn real_world_error_messages() {
        #[derive(Clone, Debug, PartialEq)]
        enum ErrorClass {
            Blocked,
            Auth,
            BackendDown,
            Transient,
        }

        let classifier: KeywordClassifier<ErrorClass> = KeywordClassifier::new(&[
            (
                &[
                    "bot detect", "blocked", "403", "captcha",
                    "checking your browser", "access denied",
                ],
                ErrorClass::Blocked,
            ),
            (&["401", "unauthorized"], ErrorClass::Auth),
            (
                &["backend unavailable", "503", "service unavailable"],
                ErrorClass::BackendDown,
            ),
            (
                &["err_connection_reset", "timeout", "websocket closed"],
                ErrorClass::Transient,
            ),
        ]);

        let expectations = [
            ("Error: 403 Forbidden - Access Denied", Some(ErrorClass::Blocked)),
            ("HTTP 401 Unauthorized", Some(ErrorClass::Auth)),
            ("503 Service Temporarily Unavailable", Some(ErrorClass::BackendDown)),
            ("net::ERR_CONNECTION_RESET at navigation", Some(ErrorClass::Transient)),
            ("Page loaded successfully", None),
        ];
        for (text, expected) in expectations {
            assert_eq!(classifier.classify(text).cloned(), expected);
        }
    }
}