use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use crate::vocabulary::Vocabulary;
use regex::Regex;
use std::collections::{HashMap, HashSet};
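/// A filter over token streams: implementors decide which tokens survive.
///
/// A minimal usage sketch; the crate path and `WordTokenizer` are assumed
/// from this module's imports, so the example is marked `ignore` rather
/// than run as a doctest:
///
/// ```ignore
/// let tokenizer = WordTokenizer::default();
/// let filter = LengthFilter::new(4, usize::MAX);
/// // Tokenize, drop short tokens, and rejoin with spaces.
/// let out = filter.filter_text("the quick brown fox", &tokenizer).unwrap();
/// assert_eq!(out, "quick brown");
/// ```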
pub trait TokenFilter {
    /// Return the subset of `tokens` that passes this filter.
    fn apply(&self, tokens: &[String]) -> Vec<String>;

    /// Tokenize `text`, apply the filter, and rejoin surviving tokens with spaces.
    fn filter_text(&self, text: &str, tokenizer: &dyn Tokenizer) -> Result<String> {
        let tokens = tokenizer.tokenize(text)?;
        let filtered = self.apply(&tokens);
        Ok(filtered.join(" "))
    }
}
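/// Keeps tokens whose length, measured in `char`s, lies within
/// `[min_length, max_length]` inclusive.
///
/// A builder-style sketch (`tokens: Vec<String>` assumed from any tokenizer):
///
/// ```ignore
/// // Keep tokens of 3 to 10 characters.
/// let filter = LengthFilter::default()
///     .with_min_length(3)
///     .with_max_length(10);
/// let kept = filter.apply(&tokens);
/// ```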
#[derive(Debug, Clone)]
pub struct LengthFilter {
    /// Minimum token length in characters (inclusive).
    pub min_length: usize,
    /// Maximum token length in characters (inclusive).
    pub max_length: usize,
}

impl Default for LengthFilter {
    fn default() -> Self {
        Self {
            min_length: 1,
            max_length: usize::MAX,
        }
    }
}

impl LengthFilter {
    pub fn new(min_length: usize, max_length: usize) -> Self {
        Self {
            min_length,
            max_length,
        }
    }

    pub fn with_min_length(mut self, min_length: usize) -> Self {
        self.min_length = min_length;
        self
    }

    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }
}

impl TokenFilter for LengthFilter {
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| {
                // Measure in chars, not bytes, so multi-byte characters count once.
                let len = token.chars().count();
                len >= self.min_length && len <= self.max_length
            })
            .cloned()
            .collect()
    }
}
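/// Filters tokens by corpus frequency: tokens with a count below `min_count`,
/// above `max_count`, or with a relative frequency above `max_freq` are dropped.
///
/// A sketch, assuming `texts: Vec<&str>` and a tokenizer are in scope:
///
/// ```ignore
/// // Learn counts from a corpus, then drop hapaxes and tokens
/// // that make up more than half of all occurrences.
/// let filter = FrequencyFilter::learn_from_corpus(&texts, &tokenizer, 2)
///     .expect("tokenization failed")
///     .with_max_freq(0.5)
///     .expect("max_freq must lie in [0.0, 1.0]");
/// let kept = filter.apply(&tokens);
/// ```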
#[derive(Debug, Clone)]
pub struct FrequencyFilter {
    /// Minimum absolute count a token needs to be kept.
    pub min_count: usize,
    /// Optional maximum absolute count.
    pub max_count: Option<usize>,
    /// Optional maximum relative frequency in `[0.0, 1.0]`.
    pub max_freq: Option<f64>,
    token_counts: HashMap<String, usize>,
    total_count: usize,
}

impl FrequencyFilter {
    pub fn from_tokens_with_vocabulary(
        tokens: &[String],
        vocabulary: &Vocabulary,
        min_count: usize,
    ) -> Self {
        let mut token_counts = HashMap::new();

        // Only count tokens that are present in the vocabulary.
        for token in tokens {
            if vocabulary.contains(token) {
                *token_counts.entry(token.clone()).or_insert(0) += 1;
            }
        }

        let total_count: usize = token_counts.values().sum();

        Self {
            min_count,
            max_count: None,
            max_freq: None,
            token_counts,
            total_count,
        }
    }

    pub fn from_counts(token_counts: HashMap<String, usize>, min_count: usize) -> Self {
        let total_count: usize = token_counts.values().sum();

        Self {
            min_count,
            max_count: None,
            max_freq: None,
            token_counts,
            total_count,
        }
    }

    pub fn learn_from_corpus(
        texts: &[&str],
        tokenizer: &dyn Tokenizer,
        min_count: usize,
    ) -> Result<Self> {
        let mut counts = HashMap::new();
        let mut total = 0;

        for &text in texts {
            let tokens = tokenizer.tokenize(text)?;
            for token in tokens {
                *counts.entry(token).or_insert(0) += 1;
                total += 1;
            }
        }

        Ok(Self {
            min_count,
            max_count: None,
            max_freq: None,
            token_counts: counts,
            total_count: total,
        })
    }

    pub fn with_max_count(mut self, max_count: usize) -> Self {
        self.max_count = Some(max_count);
        self
    }

    pub fn with_max_freq(mut self, max_freq: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&max_freq) {
            return Err(TextError::InvalidInput(
                "max_freq must be between 0.0 and 1.0".to_string(),
            ));
        }

        self.max_freq = Some(max_freq);
        Ok(self)
    }
}

impl TokenFilter for FrequencyFilter {
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| {
                let count = self.token_counts.get(*token).copied().unwrap_or(0);

                // Too rare.
                if count < self.min_count {
                    return false;
                }

                // Too common in absolute terms.
                if let Some(max_count) = self.max_count {
                    if count > max_count {
                        return false;
                    }
                }

                // Too common relative to the corpus size.
                if let Some(max_freq) = self.max_freq {
                    if self.total_count > 0 {
                        let freq = count as f64 / self.total_count as f64;
                        if freq > max_freq {
                            return false;
                        }
                    }
                }

                true
            })
            .cloned()
            .collect()
    }
}
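/// Keeps or drops tokens according to whether they match a regular expression.
///
/// A sketch of both modes:
///
/// ```ignore
/// // Keep only purely alphabetic tokens.
/// let alpha_only = RegexFilter::new(r"^[[:alpha:]]+$", true).unwrap();
/// // Drop any token containing a digit.
/// let no_digits = RegexFilter::new(r"[0-9]", false).unwrap();
/// ```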
#[derive(Debug, Clone)]
pub struct RegexFilter {
    pattern: Regex,
    /// If `true`, keep tokens that match; if `false`, keep tokens that do not.
    keep_matching: bool,
}

impl RegexFilter {
    pub fn new(pattern: &str, keep_matching: bool) -> Result<Self> {
        match Regex::new(pattern) {
            Ok(regex) => Ok(Self {
                pattern: regex,
                keep_matching,
            }),
            Err(e) => Err(TextError::InvalidInput(format!(
                "Invalid regex pattern: {e}"
            ))),
        }
    }
}

impl TokenFilter for RegexFilter {
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| {
                // Keep a token exactly when its match status agrees with the mode.
                let matches = self.pattern.is_match(token);
                matches == self.keep_matching
            })
            .cloned()
            .collect()
    }
}
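/// Removes stopwords from a token stream (or, with `remove_stopwords` set to
/// `false`, keeps only the stopwords). Matching is case-insensitive.
///
/// A sketch:
///
/// ```ignore
/// let filter = StopwordsFilter::new(vec!["the".to_string(), "a".to_string()], true);
/// let kept = filter.apply(&tokens); // every "the" and "a" is removed
/// ```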
#[derive(Debug, Clone)]
pub struct StopwordsFilter {
    stopwords: HashSet<String>,
    /// If `true`, remove stopwords; if `false`, keep only stopwords.
    remove_stopwords: bool,
}

impl StopwordsFilter {
    pub fn new(stopwords: Vec<String>, remove_stopwords: bool) -> Self {
        Self {
            // Store lowercased so matching stays case-insensitive, as `apply`
            // and `from_file` assume.
            stopwords: stopwords.into_iter().map(|w| w.to_lowercase()).collect(),
            remove_stopwords,
        }
    }

    pub fn from_file(path: &str) -> Result<Self> {
        use std::fs::File;
        use std::io::{BufRead, BufReader};

        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let reader = BufReader::new(file);

        let mut stopwords = HashSet::new();
        for line in reader.lines() {
            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
            let word = line.trim();
            // Skip blank lines and '#' comment lines.
            if !word.is_empty() && !word.starts_with('#') {
                stopwords.insert(word.to_lowercase());
            }
        }

        Ok(Self {
            stopwords,
            remove_stopwords: true,
        })
    }

    pub fn remove_stopwords(mut self, remove: bool) -> Self {
        self.remove_stopwords = remove;
        self
    }

    pub fn add_stopwords(&mut self, words: &[String]) {
        for word in words {
            self.stopwords.insert(word.to_lowercase());
        }
    }

    pub fn get_stopwords(&self) -> Vec<String> {
        self.stopwords.iter().cloned().collect()
    }
}

impl TokenFilter for StopwordsFilter {
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| {
                let is_stopword = self.stopwords.contains(&token.to_lowercase());
                // In remove mode keep non-stopwords; otherwise keep only stopwords.
                if self.remove_stopwords {
                    !is_stopword
                } else {
                    is_stopword
                }
            })
            .cloned()
            .collect()
    }
}
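/// Chains several filters into one; each stage consumes the previous stage's
/// output, so the order in which filters are added matters.
///
/// A sketch combining two of the filters above:
///
/// ```ignore
/// let composite = CompositeFilter::new()
///     .with_filter(LengthFilter::new(3, 15))
///     .with_filter(RegexFilter::new(r"^[[:alpha:]]+$", true).unwrap());
/// let kept = composite.apply(&tokens);
/// ```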
pub struct CompositeFilter {
    filters: Vec<Box<dyn TokenFilter + Send + Sync>>,
}

impl CompositeFilter {
    pub fn new() -> Self {
        Self {
            filters: Vec::new(),
        }
    }

    pub fn add_filter<F>(&mut self, filter: F)
    where
        F: TokenFilter + Send + Sync + 'static,
    {
        self.filters.push(Box::new(filter));
    }

    pub fn with_filter<F>(mut self, filter: F) -> Self
    where
        F: TokenFilter + Send + Sync + 'static,
    {
        self.add_filter(filter);
        self
    }
}

impl Default for CompositeFilter {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for CompositeFilter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeFilter")
            .field("num_filters", &self.filters.len())
            .finish()
    }
}

impl Clone for CompositeFilter {
    fn clone(&self) -> Self {
        // Boxed `dyn TokenFilter` objects cannot be cloned, so a clone starts
        // empty and callers must re-add filters. This is a known limitation.
        Self::new()
    }
}

impl TokenFilter for CompositeFilter {
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        let mut filtered = tokens.to_vec();

        // Apply filters in insertion order; each stage sees the previous output.
        for filter in &self.filters {
            filtered = filter.apply(&filtered);
        }

        filtered
    }
}
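/// Filters tokens with an arbitrary predicate closure.
///
/// A sketch:
///
/// ```ignore
/// // Keep only tokens that parse as integers.
/// let numeric = CustomFilter::new(|t: &str| t.parse::<i64>().is_ok());
/// let kept = numeric.apply(&tokens);
/// ```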
pub struct CustomFilter<F>
where
    F: Fn(&str) -> bool + Send + Sync,
{
    predicate: F,
}

impl<F> CustomFilter<F>
where
    F: Fn(&str) -> bool + Send + Sync,
{
    pub fn new(predicate: F) -> Self {
        Self { predicate }
    }
}

impl<F> TokenFilter for CustomFilter<F>
where
    F: Fn(&str) -> bool + Send + Sync,
{
    fn apply(&self, tokens: &[String]) -> Vec<String> {
        tokens
            .iter()
            .filter(|token| (self.predicate)(token))
            .cloned()
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize::WordTokenizer;

    fn get_test_tokens() -> Vec<String> {
        vec![
            "the".to_string(),
            "quick".to_string(),
            "brown".to_string(),
            "fox".to_string(),
            "jumps".to_string(),
            "over".to_string(),
            "the".to_string(),
            "lazy".to_string(),
            "dog".to_string(),
        ]
    }

    #[test]
    fn test_length_filter() {
        let tokens = get_test_tokens();

        let filter = LengthFilter::new(4, usize::MAX);
        let filtered = filter.apply(&tokens);

        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        assert_eq!(
            sorted_filtered,
            vec!["brown", "jumps", "lazy", "over", "quick"]
        );

        let filter = LengthFilter::new(3, 3);
        let filtered = filter.apply(&tokens);

        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        assert_eq!(sorted_filtered, vec!["dog", "fox", "the", "the"]);
    }

    #[test]
    fn test_frequency_filter() {
        let tokens = get_test_tokens();

        let mut counts = HashMap::new();
        for token in &tokens {
            *counts.entry(token.clone()).or_insert(0) += 1;
        }

        let filter = FrequencyFilter::from_counts(counts, 2);
        let filtered = filter.apply(&tokens);

        assert_eq!(filtered, vec!["the", "the"]);
    }

    #[test]
    fn test_regex_filter() {
        let tokens = get_test_tokens();

        let filter = RegexFilter::new(r"^b", true).unwrap();
        let filtered = filter.apply(&tokens);

        assert_eq!(filtered, vec!["brown"]);

        let test_tokens = vec![
            "the".to_string(),
            "jumps".to_string(),
            "the".to_string(),
            "lazy".to_string(),
        ];

        let filter = RegexFilter::new(r"o", false).unwrap();
        let filtered = filter.apply(&test_tokens);

        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        let expected = vec!["jumps", "lazy", "the", "the"];
        assert_eq!(sorted_filtered, expected);
    }

    #[test]
    fn test_stopwords_filter() {
        let tokens = get_test_tokens();

        let stopwords = vec!["the".to_string(), "over".to_string()];

        let filter = StopwordsFilter::new(stopwords, true);
        let filtered = filter.apply(&tokens);

        assert_eq!(
            filtered,
            vec!["quick", "brown", "fox", "jumps", "lazy", "dog"]
        );
    }

    #[test]
    fn test_composite_filter() {
        let tokens = get_test_tokens();

        let length_filter = LengthFilter::new(4, usize::MAX);
        let regex_filter = RegexFilter::new(r"o", true).unwrap();

        let composite = CompositeFilter::new()
            .with_filter(length_filter)
            .with_filter(regex_filter);

        let filtered = composite.apply(&tokens);

        assert_eq!(filtered, vec!["brown", "over"]);
    }

    #[test]
    fn test_custom_filter() {
        let tokens = get_test_tokens();

        let filter = CustomFilter::new(|token: &str| token.contains('o'));

        let filtered = filter.apply(&tokens);
        let mut sorted_filtered = filtered.clone();
        sorted_filtered.sort();
        assert_eq!(sorted_filtered, vec!["brown", "dog", "fox", "over"]);
    }

    #[test]
    fn test_filter_text() {
        let text = "The quick brown fox jumps over the lazy dog";
        let tokenizer = WordTokenizer::default();

        let filter = LengthFilter::new(5, usize::MAX);
        let filtered = filter.filter_text(text, &tokenizer).unwrap();

        assert_eq!(filtered, "quick brown jumps");
    }
}