// sensitive_rs/filter.rs

1use crate::engine::MatchAlgorithm;
2use crate::{engine::MultiPatternEngine, variant::VariantDetector};
3use lru::LruCache;
4use rayon::prelude::*;
5use regex::Regex;
6use std::num::NonZero;
7use std::sync::{Arc, Mutex};
8use std::{
9    fs::File,
10    io::{self, BufRead, BufReader},
11    path::Path,
12};
13
14/// Advanced sensitive word filter with variant detection
15pub struct Filter {
16    engine: MultiPatternEngine,        // Multi-pattern matching engine
17    variant_detector: VariantDetector, // Variation detector
18    noise: Regex,                      // Noise processing rules
19    cache: Arc<Mutex<LruCache<String, Vec<String>>>>,
20    #[cfg(feature = "net")]
21    http_client: reqwest::blocking::Client, // Network request client
22}
23
24impl Filter {
25    /// Create a new filter with default settings
26    pub fn new() -> Self {
27        Self {
28            engine: MultiPatternEngine::new(None, &[]),
29            variant_detector: VariantDetector::new(),
30            noise: Regex::new(r"[^\w\s\u4e00-\u9fff]").unwrap(),
31            cache: Arc::new(Mutex::new(LruCache::new(NonZero::new(1000).unwrap()))), // Cache 1000 results
32            #[cfg(feature = "net")]
33            http_client: reqwest::blocking::Client::builder()
34                .timeout(std::time::Duration::from_secs(5))
35                .build()
36                .unwrap(),
37        }
38    }
39
40    fn check_cache(&self, text: &str) -> Option<Vec<String>> {
41        self.cache.lock().unwrap().get(text).cloned()
42    }
43
44    fn cache_result(&self, text: &str, results: &[String]) {
45        self.cache.lock().unwrap().put(text.to_string(), results.to_vec());
46    }
47
48    /// Clear the cache
49    pub fn clear_cache(&self) {
50        self.cache.lock().unwrap().clear();
51    }
52
53    /// Create with specific algorithm
54    pub fn with_algorithm(algorithm: MatchAlgorithm) -> Self {
55        Self { engine: MultiPatternEngine::new(Some(algorithm), &[]), ..Self::new() }
56    }
57
58    /// Load default dictionary
59    pub fn with_default_dict() -> io::Result<Self> {
60        let mut filter = Self::new();
61        filter.load_word_dict("dict/dict.txt")?;
62        Ok(filter)
63    }
64
65    /// Update noise pattern
66    pub fn update_noise_pattern(&mut self, pattern: &str) {
67        self.noise = Regex::new(pattern).unwrap();
68    }
69
70    /// Add a sensitive word
71    pub fn add_word(&mut self, word: &str) {
72        let patterns = {
73            let mut p = self.engine.get_patterns().to_vec();
74            p.push(word.to_string());
75            p
76        };
77        self.engine.rebuild(&patterns);
78        self.variant_detector.add_word(word);
79    }
80
81    /// Add multiple words
82    pub fn add_words(&mut self, words: &[&str]) {
83        let mut patterns = self.engine.get_patterns().to_vec();
84        patterns.extend(words.iter().map(|s| s.to_string()));
85
86        self.engine.rebuild(&patterns);
87        for word in words {
88            self.variant_detector.add_word(word);
89        }
90    }
91
92    /// Get the currently used algorithm
93    pub fn current_algorithm(&self) -> MatchAlgorithm {
94        self.engine.current_algorithm()
95    }
96
97    /// Remove a word
98    pub fn del_word(&mut self, word: &str) {
99        let patterns: Vec<_> = self.engine.get_patterns().iter().filter(|&w| w != word).cloned().collect();
100
101        self.engine.rebuild(&patterns);
102    }
103
104    /// Remove multiple words
105    pub fn del_words(&mut self, words: &[&str]) {
106        let word_set: std::collections::HashSet<_> = words.iter().collect();
107        let patterns: Vec<_> =
108            self.engine.get_patterns().iter().filter(|w| !word_set.contains(&w.as_str())).cloned().collect();
109
110        self.engine.rebuild(&patterns);
111    }
112
113    /// Load dictionary from file
114    pub fn load_word_dict<P: AsRef<Path>>(&mut self, path: P) -> io::Result<()> {
115        let file = File::open(path)?;
116        self.load(BufReader::new(file))
117    }
118
119    /// Load dictionary from reader
120    pub fn load<R: BufRead>(&mut self, reader: R) -> io::Result<()> {
121        let words: Vec<_> = reader.lines().collect::<Result<_, _>>()?;
122        self.add_words(&words.iter().map(|s| s.as_str()).collect::<Vec<_>>());
123        Ok(())
124    }
125
126    /// Load dictionary from URL
127    #[cfg(feature = "net")]
128    pub fn load_net_word_dict(&mut self, url: &str) -> io::Result<()> {
129        let response = self.http_client.get(url).send().map_err(io::Error::other)?;
130
131        if !response.status().is_success() {
132            return Err(io::Error::other(format!("HTTP request failed: {}", response.status())));
133        }
134
135        let reader = BufReader::new(response);
136        self.load(reader)
137    }
138
139    /// Find first sensitive word
140    pub fn find_in(&self, text: &str) -> (bool, String) {
141        let clean_text = self.remove_noise(text);
142
143        // 1. Try exact match first
144        if let Some(word) = self.engine.find_first(&clean_text) {
145            return (true, word);
146        }
147
148        // 2. Try variant detection
149        let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
150
151        if let Some(word) = self.variant_detector.detect(&clean_text, &patterns).first() {
152            return (true, word.to_string());
153        }
154
155        (false, String::new())
156    }
157
158    /// Replace sensitive words with replacement character
159    pub fn replace(&self, text: &str, replacement: char) -> String {
160        let clean_text = self.remove_noise(text);
161
162        // Get all sensitive words (including variants) that need to be processed
163        let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
164        let variants = self.variant_detector.detect(&clean_text, &patterns);
165
166        let mut result = clean_text;
167
168        // Replace sensitive words detected by the engine
169        for pattern in self.engine.get_patterns() {
170            let repl_str = replacement.to_string().repeat(pattern.chars().count());
171            result = result.replace(pattern, &repl_str);
172        }
173
174        // Replace the sensitive words detected by the variant
175        for variant in variants {
176            let repl_str = replacement.to_string().repeat(variant.chars().count());
177            result = result.replace(variant, &repl_str);
178        }
179
180        result
181    }
182
183    /// Filter out sensitive words (remove them completely)
184    pub fn filter(&self, text: &str) -> String {
185        let clean_text = self.remove_noise(text);
186
187        // Get all sensitive words (including variants) that need to be processed
188        let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
189        let variants = self.variant_detector.detect(&clean_text, &patterns);
190
191        let mut result = clean_text;
192
193        // Remove sensitive words detected by the engine
194        for pattern in self.engine.get_patterns() {
195            result = result.replace(pattern, "");
196        }
197
198        // Remove sensitive words detected by variants
199        for variant in variants {
200            result = result.replace(variant, "");
201        }
202
203        result
204    }
205
206    /// Validate text
207    pub fn validate(&self, text: &str) -> (bool, String) {
208        self.find_in(text)
209    }
210
211    /// Remove only specific noise characters, preserve spaces
212    pub fn remove_noise(&self, text: &str) -> String {
213        self.noise.replace_all(text, "").to_string()
214    }
215
216    /// Get current noise pattern
217    pub fn get_noise_pattern(&self) -> &Regex {
218        &self.noise
219    }
220}
221
222impl Filter {
223    /// Optimized method of finding all sensitive words
224    pub fn find_all(&self, text: &str) -> Vec<String> {
225        let clean_text = self.remove_noise(text);
226
227        // 1. Caching mechanism - Check whether the results have been cached
228        if let Some(cached_result) = self.check_cache(&clean_text) {
229            return cached_result;
230        }
231
232        let results = if clean_text.len() > 1000 {
233            // Long text is processed in parallel
234            self.find_all_parallel(&clean_text)
235        } else {
236            // General processing of short text
237            self.find_all_sequential(&clean_text)
238        };
239
240        // 3. Cache results
241        self.cache_result(&clean_text, &results);
242
243        results
244    }
245
246    /// Parallel Processing Version - For Long Text
247    fn find_all_parallel(&self, text: &str) -> Vec<String> {
248        let chunk_size = std::cmp::max(text.len() / rayon::current_num_threads(), 100);
249        let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
250
251        // Parallel processing in segments
252        let engine_results: Vec<String> = text
253            .chars()
254            .collect::<Vec<_>>()
255            .par_chunks(chunk_size)
256            .flat_map(|chunk| {
257                let chunk_text: String = chunk.iter().collect();
258                self.engine.find_all(&chunk_text)
259            })
260            .collect();
261
262        // Parallel variant detection - Fixed parallel iterator problem
263        let variant_results: Vec<String> = text
264            .split_whitespace()
265            .collect::<Vec<_>>()
266            .par_iter()
267            .map(|segment| self.variant_detector.detect(segment, &patterns))
268            .flatten()
269            .map(|s| s.to_string())
270            .collect();
271
272        // Merge and remove repetition
273        let mut results = engine_results;
274        results.extend(variant_results);
275        self.deduplicate_and_sort(results)
276    }
277
278    /// Sequential processing version - suitable for short text
279    fn find_all_sequential(&self, text: &str) -> Vec<String> {
280        let mut results = self.engine.find_all(text);
281        let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
282
283        // Add variant detection results
284        results.extend(self.variant_detector.detect(text, &patterns).into_iter().map(|s| s.to_string()));
285
286        self.deduplicate_and_sort(results)
287    }
288
289    /// Deduplication and sort
290    fn deduplicate_and_sort(&self, mut results: Vec<String>) -> Vec<String> {
291        results.sort_unstable();
292        results.dedup();
293        results
294    }
295
296    /// Bulk search for optimized versions
297    pub fn find_all_batch(&self, texts: &[&str]) -> Vec<Vec<String>> {
298        texts.par_iter().map(|text| self.find_all(text)).collect()
299    }
300
301    /// Hierarchical Matching - Preferential Matching by Sensitive Word Length
302    pub fn find_all_layered(&self, text: &str) -> Vec<String> {
303        let clean_text = self.remove_noise(text);
304        let mut results = Vec::new();
305        let mut remaining_text = clean_text.clone();
306
307        // Arrange patterns in descending order of length, prioritize long words
308        let mut sorted_patterns = self.engine.get_patterns().to_vec();
309        sorted_patterns.sort_by_key(|b| std::cmp::Reverse(b.len()));
310
311        // Hierarchical matching
312        for pattern in &sorted_patterns {
313            if remaining_text.contains(pattern) {
314                results.push(pattern.clone());
315                // Remove matching parts to avoid duplicate matches
316                remaining_text = remaining_text.replace(pattern, " ");
317            }
318        }
319
320        // Variation detection (for remaining text)
321        let patterns: Vec<_> = sorted_patterns.iter().map(|s| s.as_str()).collect();
322        results.extend(self.variant_detector.detect(&remaining_text, &patterns).into_iter().map(|s| s.to_string()));
323
324        self.deduplicate_and_sort(results)
325    }
326
327    /// Streaming version - suitable for oversized text
328    pub fn find_all_streaming<R: BufRead>(&self, reader: R) -> io::Result<Vec<String>> {
329        let mut all_results = Vec::new();
330
331        for line in reader.lines() {
332            let line = line?;
333            let results = self.find_all(&line);
334            all_results.extend(results);
335        }
336
337        Ok(self.deduplicate_and_sort(all_results))
338    }
339}
340
341impl Default for Filter {
342    fn default() -> Self {
343        Self::new()
344    }
345}
346
#[cfg(test)]
mod tests {
    //! Integration tests for `Filter`: matching, masking, noise handling, I/O
    //! loading and algorithm selection.

    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_integration() {
        let mut filter = Filter::new();
        filter.add_words(&["赌博", "色情"]);

        // Exact match
        assert_eq!(filter.find_in("含有赌博"), (true, "赌博".to_string()));

        // Pinyin variant
        assert_eq!(filter.find_in("含有 dubo"), (true, "赌博".to_string()));

        // Replacement: one mask char per matched char, spaces kept
        assert_eq!(filter.replace("赌博 色情", '*'), "** **");

        // Filtering removes the word entirely
        assert_eq!(filter.filter("赌博内容"), "内容");
    }

    #[test]
    fn test_noise_handling() {
        let mut filter = Filter::new();
        filter.add_word("赌博");

        // Spaces are preserved
        assert_eq!(filter.remove_noise("赌 博"), "赌 博");

        // Special symbols are removed
        assert_eq!(filter.remove_noise("赌@#$博"), "赌博");
    }

    #[test]
    fn test_replace_vs_filter() {
        let mut filter = Filter::new();
        filter.add_words(&["赌博", "色情"]);

        let text = "这里有赌博和色情内容";

        // `replace` masks each matched char
        assert_eq!(filter.replace(text, '*'), "这里有**和**内容");

        // `filter` removes the words completely
        assert_eq!(filter.filter(text), "这里有和内容");
    }

    #[test]
    fn test_variant_detection() {
        let mut filter = Filter::new();
        filter.add_word("测试");

        assert_eq!(filter.find_in("ceshi"), (true, "测试".to_string()));
    }

    #[test]
    fn test_io_operations() -> io::Result<()> {
        let mut filter = Filter::new();
        let cursor = Cursor::new("word1\nword2\nword3");
        filter.load(cursor)?;

        assert_eq!(filter.find_in("word2"), (true, "word2".to_string()));
        Ok(())
    }

    #[test]
    fn test_algorithm_recommendation() {
        assert_eq!(MultiPatternEngine::recommend_algorithm(50), MatchAlgorithm::WuManber);
        assert_eq!(MultiPatternEngine::recommend_algorithm(150), MatchAlgorithm::AhoCorasick);
        assert_eq!(MultiPatternEngine::recommend_algorithm(15000), MatchAlgorithm::Regex);
    }

    #[test]
    fn test_algorithm_switch() {
        // A handful of words: Wu-Manber, seen both via the wrapper and the engine
        let mut small = Filter::new();
        small.add_words(&["a", "b", "c"]);
        assert!(matches!(small.current_algorithm(), MatchAlgorithm::WuManber));
        assert!(matches!(small.engine.current_algorithm(), MatchAlgorithm::WuManber));

        // A medium-sized list: Aho-Corasick, matching the engine's recommendation
        let words: Vec<String> = (0..150).map(|i| format!("word{i}")).collect();
        let word_refs: Vec<&str> = words.iter().map(String::as_str).collect();

        let mut medium = Filter::new();
        medium.add_words(&word_refs);

        assert_eq!(medium.engine.get_patterns().len(), 150);
        assert_eq!(medium.current_algorithm(), MultiPatternEngine::recommend_algorithm(150));
        assert!(matches!(medium.current_algorithm(), MatchAlgorithm::AhoCorasick));
        assert!(matches!(medium.engine.current_algorithm(), MatchAlgorithm::AhoCorasick));
    }
}