1use crate::engine::MatchAlgorithm;
2use crate::{engine::MultiPatternEngine, variant::VariantDetector};
3use lru::LruCache;
4use rayon::prelude::*;
5use regex::Regex;
6use std::num::NonZero;
7use std::sync::{Arc, Mutex};
8use std::{
9 fs::File,
10 io::{self, BufRead, BufReader},
11 path::Path,
12};
13
14pub struct Filter {
16 engine: MultiPatternEngine, variant_detector: VariantDetector, noise: Regex, cache: Arc<Mutex<LruCache<String, Vec<String>>>>,
20 #[cfg(feature = "net")]
21 http_client: reqwest::blocking::Client, }
23
24impl Filter {
25 pub fn new() -> Self {
27 Self {
28 engine: MultiPatternEngine::new(None, &[]),
29 variant_detector: VariantDetector::new(),
30 noise: Regex::new(r"[^\w\s\u4e00-\u9fff]").unwrap(),
31 cache: Arc::new(Mutex::new(LruCache::new(NonZero::new(1000).unwrap()))), #[cfg(feature = "net")]
33 http_client: reqwest::blocking::Client::builder()
34 .timeout(std::time::Duration::from_secs(5))
35 .build()
36 .unwrap(),
37 }
38 }
39
40 fn check_cache(&self, text: &str) -> Option<Vec<String>> {
41 self.cache.lock().unwrap().get(text).cloned()
42 }
43
44 fn cache_result(&self, text: &str, results: &[String]) {
45 self.cache.lock().unwrap().put(text.to_string(), results.to_vec());
46 }
47
48 pub fn clear_cache(&self) {
50 self.cache.lock().unwrap().clear();
51 }
52
53 pub fn with_algorithm(algorithm: MatchAlgorithm) -> Self {
55 Self { engine: MultiPatternEngine::new(Some(algorithm), &[]), ..Self::new() }
56 }
57
58 pub fn with_default_dict() -> io::Result<Self> {
60 let mut filter = Self::new();
61 filter.load_word_dict("dict/dict.txt")?;
62 Ok(filter)
63 }
64
65 pub fn update_noise_pattern(&mut self, pattern: &str) {
67 self.noise = Regex::new(pattern).unwrap();
68 }
69
70 pub fn add_word(&mut self, word: &str) {
72 let patterns = {
73 let mut p = self.engine.get_patterns().to_vec();
74 p.push(word.to_string());
75 p
76 };
77 self.engine.rebuild(&patterns);
78 self.variant_detector.add_word(word);
79 }
80
81 pub fn add_words(&mut self, words: &[&str]) {
83 let mut patterns = self.engine.get_patterns().to_vec();
84 patterns.extend(words.iter().map(|s| s.to_string()));
85
86 self.engine.rebuild(&patterns);
87 for word in words {
88 self.variant_detector.add_word(word);
89 }
90 }
91
92 pub fn current_algorithm(&self) -> MatchAlgorithm {
94 self.engine.current_algorithm()
95 }
96
97 pub fn del_word(&mut self, word: &str) {
99 let patterns: Vec<_> = self.engine.get_patterns().iter().filter(|&w| w != word).cloned().collect();
100
101 self.engine.rebuild(&patterns);
102 }
103
104 pub fn del_words(&mut self, words: &[&str]) {
106 let word_set: std::collections::HashSet<_> = words.iter().collect();
107 let patterns: Vec<_> =
108 self.engine.get_patterns().iter().filter(|w| !word_set.contains(&w.as_str())).cloned().collect();
109
110 self.engine.rebuild(&patterns);
111 }
112
113 pub fn load_word_dict<P: AsRef<Path>>(&mut self, path: P) -> io::Result<()> {
115 let file = File::open(path)?;
116 self.load(BufReader::new(file))
117 }
118
119 pub fn load<R: BufRead>(&mut self, reader: R) -> io::Result<()> {
121 let words: Vec<_> = reader.lines().collect::<Result<_, _>>()?;
122 self.add_words(&words.iter().map(|s| s.as_str()).collect::<Vec<_>>());
123 Ok(())
124 }
125
126 #[cfg(feature = "net")]
128 pub fn load_net_word_dict(&mut self, url: &str) -> io::Result<()> {
129 let response = self.http_client.get(url).send().map_err(io::Error::other)?;
130
131 if !response.status().is_success() {
132 return Err(io::Error::other(format!("HTTP request failed: {}", response.status())));
133 }
134
135 let reader = BufReader::new(response);
136 self.load(reader)
137 }
138
139 pub fn find_in(&self, text: &str) -> (bool, String) {
141 let clean_text = self.remove_noise(text);
142
143 if let Some(word) = self.engine.find_first(&clean_text) {
145 return (true, word);
146 }
147
148 let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
150
151 if let Some(word) = self.variant_detector.detect(&clean_text, &patterns).first() {
152 return (true, word.to_string());
153 }
154
155 (false, String::new())
156 }
157
158 pub fn replace(&self, text: &str, replacement: char) -> String {
160 let clean_text = self.remove_noise(text);
161
162 let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
164 let variants = self.variant_detector.detect(&clean_text, &patterns);
165
166 let mut result = clean_text;
167
168 for pattern in self.engine.get_patterns() {
170 let repl_str = replacement.to_string().repeat(pattern.chars().count());
171 result = result.replace(pattern, &repl_str);
172 }
173
174 for variant in variants {
176 let repl_str = replacement.to_string().repeat(variant.chars().count());
177 result = result.replace(variant, &repl_str);
178 }
179
180 result
181 }
182
183 pub fn filter(&self, text: &str) -> String {
185 let clean_text = self.remove_noise(text);
186
187 let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
189 let variants = self.variant_detector.detect(&clean_text, &patterns);
190
191 let mut result = clean_text;
192
193 for pattern in self.engine.get_patterns() {
195 result = result.replace(pattern, "");
196 }
197
198 for variant in variants {
200 result = result.replace(variant, "");
201 }
202
203 result
204 }
205
206 pub fn validate(&self, text: &str) -> (bool, String) {
208 self.find_in(text)
209 }
210
211 pub fn remove_noise(&self, text: &str) -> String {
213 self.noise.replace_all(text, "").to_string()
214 }
215
216 pub fn get_noise_pattern(&self) -> &Regex {
218 &self.noise
219 }
220}
221
222impl Filter {
223 pub fn find_all(&self, text: &str) -> Vec<String> {
225 let clean_text = self.remove_noise(text);
226
227 if let Some(cached_result) = self.check_cache(&clean_text) {
229 return cached_result;
230 }
231
232 let results = if clean_text.len() > 1000 {
233 self.find_all_parallel(&clean_text)
235 } else {
236 self.find_all_sequential(&clean_text)
238 };
239
240 self.cache_result(&clean_text, &results);
242
243 results
244 }
245
246 fn find_all_parallel(&self, text: &str) -> Vec<String> {
248 let chunk_size = std::cmp::max(text.len() / rayon::current_num_threads(), 100);
249 let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
250
251 let engine_results: Vec<String> = text
253 .chars()
254 .collect::<Vec<_>>()
255 .par_chunks(chunk_size)
256 .flat_map(|chunk| {
257 let chunk_text: String = chunk.iter().collect();
258 self.engine.find_all(&chunk_text)
259 })
260 .collect();
261
262 let variant_results: Vec<String> = text
264 .split_whitespace()
265 .collect::<Vec<_>>()
266 .par_iter()
267 .map(|segment| self.variant_detector.detect(segment, &patterns))
268 .flatten()
269 .map(|s| s.to_string())
270 .collect();
271
272 let mut results = engine_results;
274 results.extend(variant_results);
275 self.deduplicate_and_sort(results)
276 }
277
278 fn find_all_sequential(&self, text: &str) -> Vec<String> {
280 let mut results = self.engine.find_all(text);
281 let patterns: Vec<_> = self.engine.get_patterns().iter().map(|s| s.as_str()).collect();
282
283 results.extend(self.variant_detector.detect(text, &patterns).into_iter().map(|s| s.to_string()));
285
286 self.deduplicate_and_sort(results)
287 }
288
289 fn deduplicate_and_sort(&self, mut results: Vec<String>) -> Vec<String> {
291 results.sort_unstable();
292 results.dedup();
293 results
294 }
295
296 pub fn find_all_batch(&self, texts: &[&str]) -> Vec<Vec<String>> {
298 texts.par_iter().map(|text| self.find_all(text)).collect()
299 }
300
301 pub fn find_all_layered(&self, text: &str) -> Vec<String> {
303 let clean_text = self.remove_noise(text);
304 let mut results = Vec::new();
305 let mut remaining_text = clean_text.clone();
306
307 let mut sorted_patterns = self.engine.get_patterns().to_vec();
309 sorted_patterns.sort_by_key(|b| std::cmp::Reverse(b.len()));
310
311 for pattern in &sorted_patterns {
313 if remaining_text.contains(pattern) {
314 results.push(pattern.clone());
315 remaining_text = remaining_text.replace(pattern, " ");
317 }
318 }
319
320 let patterns: Vec<_> = sorted_patterns.iter().map(|s| s.as_str()).collect();
322 results.extend(self.variant_detector.detect(&remaining_text, &patterns).into_iter().map(|s| s.to_string()));
323
324 self.deduplicate_and_sort(results)
325 }
326
327 pub fn find_all_streaming<R: BufRead>(&self, reader: R) -> io::Result<Vec<String>> {
329 let mut all_results = Vec::new();
330
331 for line in reader.lines() {
332 let line = line?;
333 let results = self.find_all(&line);
334 all_results.extend(results);
335 }
336
337 Ok(self.deduplicate_and_sort(all_results))
338 }
339}
340
341impl Default for Filter {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a fresh filter whose dictionary is exactly `words`.
    fn filter_with(words: &[&str]) -> Filter {
        let mut filter = Filter::new();
        filter.add_words(words);
        filter
    }

    /// Produces `word0`, `word1`, ... up to `count` entries.
    fn numbered_words(count: usize) -> Vec<String> {
        (0..count).map(|i| format!("word{i}")).collect()
    }

    #[test]
    fn test_integration() {
        let filter = filter_with(&["赌博", "色情"]);
        let gambling = (true, "赌博".to_string());

        // Exact hit.
        assert_eq!(filter.find_in("含有赌博"), gambling);
        // Pinyin spelling is resolved back to the dictionary word.
        assert_eq!(filter.find_in("含有 dubo"), gambling);
        // Masking keeps one '*' per character, and whitespace survives.
        assert_eq!(filter.replace("赌博 色情", '*'), "** **");
        // Filtering deletes the word outright.
        assert_eq!(filter.filter("赌博内容"), "内容");
    }

    #[test]
    fn test_noise_handling() {
        let mut filter = Filter::new();
        filter.add_word("赌博");

        // Whitespace is not noise...
        assert_eq!(filter.remove_noise("赌 博"), "赌 博");
        // ...but punctuation and symbols are.
        assert_eq!(filter.remove_noise("赌@#$博"), "赌博");
    }

    #[test]
    fn test_replace_vs_filter() {
        let filter = filter_with(&["赌博", "色情"]);
        let sample = "这里有赌博和色情内容";

        let masked = filter.replace(sample, '*');
        let stripped = filter.filter(sample);

        assert_eq!(masked, "这里有**和**内容");
        assert_eq!(stripped, "这里有和内容");
    }

    #[test]
    fn test_variant_detection() {
        let mut filter = Filter::new();
        filter.add_word("测试");

        let expected = (true, "测试".to_string());
        assert_eq!(filter.find_in("ceshi"), expected);
    }

    #[test]
    fn test_algorithm_switch_one() {
        let small = filter_with(&["a", "b", "c"]);
        assert!(matches!(small.engine.current_algorithm(), MatchAlgorithm::WuManber));

        let owned = numbered_words(150);
        let medium = filter_with(&owned.iter().map(String::as_str).collect::<Vec<_>>());
        println!("Medium current_algorithm: {:?}", medium.engine.current_algorithm());
        assert!(matches!(medium.engine.current_algorithm(), MatchAlgorithm::AhoCorasick));
    }

    #[test]
    fn test_io_operations() -> io::Result<()> {
        let mut filter = Filter::new();
        filter.load(Cursor::new("word1\nword2\nword3"))?;

        let expected = (true, "word2".to_string());
        assert_eq!(filter.find_in("word2"), expected);
        Ok(())
    }

    #[test]
    fn test_algorithm_recommendation() {
        let cases = [
            (50, MatchAlgorithm::WuManber),
            (150, MatchAlgorithm::AhoCorasick),
            (15000, MatchAlgorithm::Regex),
        ];
        for (count, expected) in cases {
            assert_eq!(MultiPatternEngine::recommend_algorithm(count), expected);
        }
    }

    #[test]
    fn test_algorithm_switch() {
        let small = filter_with(&["a", "b", "c"]);
        println!("Small (3 words): {:?}", small.current_algorithm());
        assert!(matches!(small.current_algorithm(), MatchAlgorithm::WuManber));

        let owned = numbered_words(150);
        let borrowed: Vec<&str> = owned.iter().map(String::as_str).collect();
        let medium = filter_with(&borrowed);

        println!("Medium (150 words): {:?}", medium.current_algorithm());
        println!("Pattern count: {}", medium.engine.get_patterns().len());

        let recommended = MultiPatternEngine::recommend_algorithm(150);
        println!("Recommended algorithm for 150 words: {recommended:?}");

        assert!(matches!(medium.current_algorithm(), MatchAlgorithm::AhoCorasick));
    }
}