1use crate::transforms::Transform;
17use std::collections::HashSet;
18use torsh_core::error::Result;
19
20#[derive(Debug, Clone, Default)]
22pub struct ToLowercase;
23
24impl Transform<String> for ToLowercase {
25 type Output = String;
26
27 fn transform(&self, input: String) -> Result<Self::Output> {
28 Ok(input.to_lowercase())
29 }
30}
31
32#[derive(Debug, Clone, Default)]
34pub struct RemovePunctuation;
35
36impl Transform<String> for RemovePunctuation {
37 type Output = String;
38
39 fn transform(&self, input: String) -> Result<Self::Output> {
40 Ok(input
41 .chars()
42 .filter(|c| !c.is_ascii_punctuation())
43 .collect())
44 }
45}
46
/// Splits text into tokens around a configurable delimiter.
///
/// An empty delimiter is the marker for "split on any run of whitespace".
#[derive(Debug, Clone)]
pub struct Tokenize {
    delimiter: String,
}

impl Tokenize {
    /// Creates a tokenizer that splits on `delimiter`.
    pub fn new(delimiter: String) -> Self {
        Tokenize { delimiter }
    }

    /// Creates a tokenizer whose delimiter is exactly one space character.
    pub fn whitespace() -> Self {
        Self::new(String::from(" "))
    }

    /// Creates a tokenizer with the empty delimiter, i.e. the
    /// split-on-any-whitespace mode.
    pub fn any_whitespace() -> Self {
        Self::new(String::new())
    }
}
69
70impl Transform<String> for Tokenize {
71 type Output = Vec<String>;
72
73 fn transform(&self, input: String) -> Result<Self::Output> {
74 if self.delimiter.is_empty() {
75 Ok(input.split_whitespace().map(|s| s.to_string()).collect())
77 } else {
78 Ok(input
79 .split(&self.delimiter)
80 .map(|s| s.to_string())
81 .collect())
82 }
83 }
84}
85
86#[derive(Debug, Clone, Default)]
88pub struct TrimWhitespace;
89
90impl Transform<String> for TrimWhitespace {
91 type Output = String;
92
93 fn transform(&self, input: String) -> Result<Self::Output> {
94 Ok(input.trim().to_string())
95 }
96}
97
98#[derive(Debug, Clone, Default)]
100pub struct CollapseWhitespace;
101
102impl Transform<String> for CollapseWhitespace {
103 type Output = String;
104
105 fn transform(&self, input: String) -> Result<Self::Output> {
106 let mut result = String::with_capacity(input.len());
107 let mut prev_was_space = false;
108
109 for ch in input.chars() {
110 if ch.is_whitespace() {
111 if !prev_was_space {
112 result.push(' ');
113 prev_was_space = true;
114 }
115 } else {
116 result.push(ch);
117 prev_was_space = false;
118 }
119 }
120
121 Ok(result.trim().to_string())
122 }
123}
124
125#[derive(Debug, Clone, Default)]
127pub struct RemoveNumbers;
128
129impl Transform<String> for RemoveNumbers {
130 type Output = String;
131
132 fn transform(&self, input: String) -> Result<Self::Output> {
133 Ok(input.chars().filter(|c| !c.is_ascii_digit()).collect())
134 }
135}
136
/// Filters tokens that appear in a stopword set.
///
/// Stopwords are stored lowercase so that lookups performed on lowercased
/// tokens match regardless of how the stopword was originally spelled.
#[derive(Debug, Clone)]
pub struct RemoveStopwords {
    stopwords: HashSet<String>,
}

impl RemoveStopwords {
    /// Creates a filter pre-loaded with a built-in list of common English words.
    ///
    /// The list contains repeated entries; they collapse in the set.
    pub fn english() -> Self {
        // NOTE(review): "mea" looks like a typo in the list; it is kept so the
        // set of filtered words does not change.
        let stopwords = [
            "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "he", "in",
            "is", "it", "its", "of", "on", "that", "the", "to", "was", "were", "will", "with",
            "the", "this", "but", "they", "have", "had", "what", "said", "each", "which", "their",
            "time", "will", "about", "if", "up", "out", "many", "then", "them", "these", "so",
            "some", "her", "would", "make", "like", "into", "him", "has", "two", "more", "go",
            "no", "way", "could", "my", "than", "first", "been", "call", "who", "oil", "sit",
            "now", "find", "down", "day", "did", "get", "come", "made", "may", "part", "over",
            "new", "sound", "take", "only", "little", "work", "know", "place", "year", "live",
            "me", "back", "give", "most", "very", "after", "thing", "our", "just", "name", "good",
            "sentence", "man", "think", "say", "great", "where", "help", "through", "much",
            "before", "line", "right", "too", "mean", "old", "any", "same", "tell", "boy",
            "follow", "came", "want", "show", "also", "around", "form", "three", "small", "set",
            "put", "end", "why", "again", "turn", "here", "off", "went", "old", "number", "great",
            "tell", "men", "say", "small", "every", "found", "still", "between", "mea", "another",
            "even", "why", "must", "big", "because", "does", "each", "how", "let", "might", "move",
            "own", "seem", "such", "turn", "under", "well", "without", "see", "use",
        ]
        .iter()
        .map(|&word| word.to_string())
        .collect();

        Self { stopwords }
    }

    /// Creates a filter from a caller-supplied word list.
    ///
    /// Each word is lowercased before being stored, matching
    /// [`add_stopword`](Self::add_stopword); otherwise a stopword containing
    /// an uppercase letter could never match a lowercased token.
    pub fn new(stopwords: Vec<String>) -> Self {
        Self {
            stopwords: stopwords
                .into_iter()
                .map(|word| word.to_lowercase())
                .collect(),
        }
    }

    /// Adds `word` (lowercased) to the stopword set.
    pub fn add_stopword(&mut self, word: String) {
        self.stopwords.insert(word.to_lowercase());
    }

    /// Returns the number of distinct stopwords currently stored.
    pub fn stopword_count(&self) -> usize {
        self.stopwords.len()
    }
}
188
189impl Transform<Vec<String>> for RemoveStopwords {
190 type Output = Vec<String>;
191
192 fn transform(&self, input: Vec<String>) -> Result<Self::Output> {
193 Ok(input
194 .into_iter()
195 .filter(|word| !self.stopwords.contains(&word.to_lowercase()))
196 .collect())
197 }
198}
199
/// Stateless suffix stripper providing the step 1a and step 1b rules of the
/// Porter stemming algorithm.
///
/// Only `a`, `e`, `i`, `o`, `u` (and `y` directly after a consonant) count as
/// vowels. Every other character, including accented letters, is treated as a
/// consonant. All positions are character positions, so non-ASCII input is
/// handled without panicking.
#[derive(Debug, Clone, Default)]
pub struct PorterStemmer;

impl PorterStemmer {
    /// Classifies each character of `word` as vowel (`true`) or consonant
    /// (`false`) in a single forward pass.
    ///
    /// `y` is a vowel only when the preceding character is a consonant; a
    /// leading `y` is a consonant.
    fn vowel_flags(word: &str) -> Vec<bool> {
        // Byte length is an upper bound on the number of characters.
        let mut flags: Vec<bool> = Vec::with_capacity(word.len());
        for ch in word.chars() {
            let is_vowel = match ch {
                'a' | 'e' | 'i' | 'o' | 'u' => true,
                'y' => matches!(flags.last(), Some(false)),
                _ => false,
            };
            flags.push(is_vowel);
        }
        flags
    }

    /// Returns the Porter "measure" of `word`: the number of vowel-run →
    /// consonant-run transitions, i.e. `m` in the form `[C](VC)^m[V]`.
    fn measure(&self, word: &str) -> usize {
        // Each VC group contributes exactly one vowel directly followed by a
        // consonant.
        Self::vowel_flags(word)
            .windows(2)
            .filter(|pair| pair[0] && !pair[1])
            .count()
    }

    /// Returns whether `word` ends with `suffix`.
    fn ends_with(&self, word: &str, suffix: &str) -> bool {
        word.ends_with(suffix)
    }

    /// Replaces a trailing `old_suffix` of `word` with `new_suffix`; returns
    /// `word` unchanged when it does not end with `old_suffix`.
    fn replace_suffix(&self, word: &str, old_suffix: &str, new_suffix: &str) -> String {
        match word.strip_suffix(old_suffix) {
            Some(stem) => format!("{stem}{new_suffix}"),
            None => word.to_string(),
        }
    }

    /// Step 1a: plural handling (`sses` → `ss`, `ies` → `i`, `ss` kept,
    /// trailing `s` dropped unless the word is just `"s"`).
    fn step1a(&self, word: &str) -> String {
        if self.ends_with(word, "sses") {
            self.replace_suffix(word, "sses", "ss")
        } else if self.ends_with(word, "ies") {
            self.replace_suffix(word, "ies", "i")
        } else if self.ends_with(word, "ss") {
            word.to_string()
        } else if self.ends_with(word, "s") && word.len() > 1 {
            self.replace_suffix(word, "s", "")
        } else {
            word.to_string()
        }
    }

    /// Step 1b (simplified): handles the `eed`, `ed` and `ing` suffixes.
    ///
    /// * `eed` → `ee` when the stem has measure > 0.
    /// * `ed` is removed when the stem contains a vowel; a stem then ending in
    ///   `at`, `bl` or `iz` gets an `e` appended.
    /// * `ing` is removed when the stem contains a vowel.
    fn step1b(&self, word: &str) -> String {
        // The suffixes are ASCII, so `strip_suffix` always cuts on a character
        // boundary even when the stem is non-ASCII.
        if let Some(stem) = word.strip_suffix("eed") {
            if self.measure(stem) > 0 {
                format!("{stem}ee")
            } else {
                word.to_string()
            }
        } else if let Some(stem) = word.strip_suffix("ed") {
            if !self.contains_vowel(stem) {
                word.to_string()
            } else if ["at", "bl", "iz"]
                .iter()
                .any(|&tail| self.ends_with(stem, tail))
            {
                format!("{stem}e")
            } else {
                stem.to_string()
            }
        } else if let Some(stem) = word.strip_suffix("ing") {
            // NOTE(review): unlike the `ed` branch, no `at`/`bl`/`iz` → `+e`
            // restoration is applied here; kept as-is to preserve outputs.
            if self.contains_vowel(stem) {
                stem.to_string()
            } else {
                word.to_string()
            }
        } else {
            word.to_string()
        }
    }

    /// Returns whether any character of `word` is classified as a vowel.
    fn contains_vowel(&self, word: &str) -> bool {
        Self::vowel_flags(word).contains(&true)
    }
}
329
330impl Transform<String> for PorterStemmer {
331 type Output = String;
332
333 fn transform(&self, input: String) -> Result<Self::Output> {
334 if input.len() <= 2 {
335 return Ok(input);
336 }
337
338 let word = input.to_lowercase();
339 let word = self.step1a(&word);
340 let word = self.step1b(&word);
341
342 Ok(word)
343 }
344}
345
/// Configuration for producing n-grams of a fixed size `n`.
#[derive(Debug, Clone)]
pub struct NGramGenerator {
    n: usize,
}

impl NGramGenerator {
    /// Creates a generator for n-grams of size `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    pub fn new(n: usize) -> Self {
        if n == 0 {
            panic!("N must be greater than 0");
        }
        NGramGenerator { n }
    }

    /// Creates a generator with `n == 2`.
    pub fn bigram() -> Self {
        NGramGenerator::new(2)
    }

    /// Creates a generator with `n == 3`.
    pub fn trigram() -> Self {
        NGramGenerator::new(3)
    }

    /// Creates a generator with `n == 1`.
    pub fn unigram() -> Self {
        NGramGenerator::new(1)
    }

    /// Returns the configured n-gram size.
    pub fn n(&self) -> usize {
        let NGramGenerator { n } = self;
        *n
    }
}
379
380impl Transform<Vec<String>> for NGramGenerator {
381 type Output = Vec<String>;
382
383 fn transform(&self, input: Vec<String>) -> Result<Self::Output> {
384 if input.len() < self.n {
385 return Ok(Vec::new());
386 }
387
388 let mut ngrams = Vec::new();
389 for i in 0..=input.len() - self.n {
390 let ngram = input[i..i + self.n].join(" ");
391 ngrams.push(ngram);
392 }
393
394 Ok(ngrams)
395 }
396}
397
/// Length bounds used to filter tokens: an inclusive minimum and an optional
/// inclusive maximum.
#[derive(Debug, Clone)]
pub struct FilterByLength {
    min_length: usize,
    max_length: Option<usize>,
}

impl FilterByLength {
    /// Creates a filter with the given minimum and optional maximum length.
    pub fn new(min_length: usize, max_length: Option<usize>) -> Self {
        FilterByLength {
            min_length,
            max_length,
        }
    }

    /// Creates a filter with only a lower bound.
    pub fn min_only(min_length: usize) -> Self {
        FilterByLength {
            min_length,
            max_length: None,
        }
    }

    /// Creates a filter with only an upper bound (minimum is 0).
    pub fn max_only(max_length: usize) -> Self {
        FilterByLength {
            min_length: 0,
            max_length: Some(max_length),
        }
    }

    /// Creates a filter with both bounds set.
    pub fn range(min_length: usize, max_length: usize) -> Self {
        FilterByLength {
            min_length,
            max_length: Some(max_length),
        }
    }
}
429
430impl Transform<Vec<String>> for FilterByLength {
431 type Output = Vec<String>;
432
433 fn transform(&self, input: Vec<String>) -> Result<Self::Output> {
434 Ok(input
435 .into_iter()
436 .filter(|word| {
437 let len = word.len();
438 len >= self.min_length && self.max_length.map_or(true, |max| len <= max)
439 })
440 .collect())
441 }
442}
443
/// A literal substring and the text that should replace it.
#[derive(Debug, Clone)]
pub struct ReplacePattern {
    pattern: String,
    replacement: String,
}

impl ReplacePattern {
    /// Creates a transform that substitutes `replacement` for `pattern`.
    pub fn new(pattern: String, replacement: String) -> Self {
        ReplacePattern {
            pattern,
            replacement,
        }
    }

    /// Creates a transform that deletes `pattern` (empty replacement).
    pub fn remove(pattern: String) -> Self {
        let nothing = String::new();
        Self::new(pattern, nothing)
    }
}
465
466impl Transform<String> for ReplacePattern {
467 type Output = String;
468
469 fn transform(&self, input: String) -> Result<Self::Output> {
470 Ok(input.replace(&self.pattern, &self.replacement))
471 }
472}
473
/// Target letter case for [`ChangeCase`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaseMode {
    /// Convert all letters to lowercase.
    Lower,
    /// Convert all letters to uppercase.
    Upper,
    /// Capitalize the first letter of each word and lowercase the rest.
    Title,
}

/// Case-conversion transform configured with a [`CaseMode`].
#[derive(Debug, Clone)]
pub struct ChangeCase {
    mode: CaseMode,
}

impl ChangeCase {
    /// Creates a transform that applies `mode`.
    pub fn new(mode: CaseMode) -> Self {
        ChangeCase { mode }
    }

    /// Shorthand for `ChangeCase::new(CaseMode::Lower)`.
    pub fn lower() -> Self {
        ChangeCase {
            mode: CaseMode::Lower,
        }
    }

    /// Shorthand for `ChangeCase::new(CaseMode::Upper)`.
    pub fn upper() -> Self {
        ChangeCase {
            mode: CaseMode::Upper,
        }
    }

    /// Shorthand for `ChangeCase::new(CaseMode::Title)`.
    pub fn title() -> Self {
        ChangeCase {
            mode: CaseMode::Title,
        }
    }
}
512
513impl Transform<String> for ChangeCase {
514 type Output = String;
515
516 fn transform(&self, input: String) -> Result<Self::Output> {
517 match self.mode {
518 CaseMode::Lower => Ok(input.to_lowercase()),
519 CaseMode::Upper => Ok(input.to_uppercase()),
520 CaseMode::Title => {
521 let mut result = String::with_capacity(input.len());
522 let mut capitalize_next = true;
523
524 for ch in input.chars() {
525 if ch.is_alphabetic() {
526 if capitalize_next {
527 result.push(ch.to_uppercase().next().unwrap_or(ch));
528 capitalize_next = false;
529 } else {
530 result.push(ch.to_lowercase().next().unwrap_or(ch));
531 }
532 } else {
533 result.push(ch);
534 capitalize_next = ch.is_whitespace();
535 }
536 }
537
538 Ok(result)
539 }
540 }
541 }
542}
543
544pub fn tokenize_whitespace() -> Tokenize {
548 Tokenize::whitespace()
549}
550
551pub fn tokenize(delimiter: &str) -> Tokenize {
553 Tokenize::new(delimiter.to_string())
554}
555
556pub fn remove_english_stopwords() -> RemoveStopwords {
558 RemoveStopwords::english()
559}
560
561pub fn porter_stemmer() -> PorterStemmer {
563 PorterStemmer
564}
565
566pub fn bigrams() -> NGramGenerator {
568 NGramGenerator::bigram()
569}
570
571pub fn trigrams() -> NGramGenerator {
573 NGramGenerator::trigram()
574}
575
576pub fn filter_by_length(min: usize, max: Option<usize>) -> FilterByLength {
578 FilterByLength::new(min, max)
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned token vectors the transforms consume.
    fn tokens(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    #[test]
    fn test_to_lowercase() {
        let lowered = ToLowercase.transform("Hello World".to_string()).unwrap();
        assert_eq!(lowered, "hello world");
    }

    #[test]
    fn test_remove_punctuation() {
        let cleaned = RemovePunctuation
            .transform("Hello, World!".to_string())
            .unwrap();
        assert_eq!(cleaned, "Hello World");
    }

    #[test]
    fn test_tokenize_whitespace() {
        let tokenizer = Tokenize::whitespace();
        let pieces = tokenizer.transform("hello world test".to_string()).unwrap();
        assert_eq!(pieces, vec!["hello", "world", "test"]);
    }

    #[test]
    fn test_tokenize_custom_delimiter() {
        let tokenizer = Tokenize::new(",".to_string());
        let pieces = tokenizer.transform("a,b,c".to_string()).unwrap();
        assert_eq!(pieces, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_trim_whitespace() {
        let trimmed = TrimWhitespace
            .transform(" hello world ".to_string())
            .unwrap();
        assert_eq!(trimmed, "hello world");
    }

    #[test]
    fn test_collapse_whitespace() {
        let collapsed = CollapseWhitespace
            .transform("hello world test".to_string())
            .unwrap();
        assert_eq!(collapsed, "hello world test");
    }

    #[test]
    fn test_remove_numbers() {
        let stripped = RemoveNumbers
            .transform("hello123world456".to_string())
            .unwrap();
        assert_eq!(stripped, "helloworld");
    }

    #[test]
    fn test_remove_stopwords() {
        let filter = RemoveStopwords::english();
        let kept = filter.transform(tokens(&["the", "quick", "brown"])).unwrap();
        assert_eq!(kept, vec!["quick", "brown"]);
    }

    #[test]
    fn test_porter_stemmer() {
        let stemmer = PorterStemmer;

        // (input word, expected stem)
        let cases = [
            ("running", "runn"),
            ("flies", "fli"),
            ("died", "di"),
            ("agreed", "agree"),
            ("sing", "sing"),
        ];

        for &(word, expected) in &cases {
            let stem = stemmer.transform(word.to_string()).unwrap();
            assert_eq!(stem, expected);
        }
    }

    #[test]
    fn test_ngram_generator() {
        let words = tokens(&["the", "quick", "brown", "fox"]);

        let pairs = NGramGenerator::bigram()
            .transform(words.clone())
            .unwrap();
        assert_eq!(pairs, vec!["the quick", "quick brown", "brown fox"]);

        let triples = NGramGenerator::trigram().transform(words).unwrap();
        assert_eq!(triples, vec!["the quick brown", "quick brown fox"]);
    }

    #[test]
    fn test_length_filter() {
        let words = tokens(&["a", "the", "quick", "brown", "foxes"]);

        let bounded = FilterByLength::new(3, Some(5));
        let kept = bounded.transform(words).unwrap();
        assert_eq!(kept, vec!["the", "quick", "brown", "foxes"]);
    }

    #[test]
    fn test_case_transforms() {
        let sample = "Hello World Test".to_string();

        let lowered = ChangeCase::lower().transform(sample.clone()).unwrap();
        assert_eq!(lowered, "hello world test");

        let raised = ChangeCase::upper().transform(sample.clone()).unwrap();
        assert_eq!(raised, "HELLO WORLD TEST");

        let titled = ChangeCase::title()
            .transform("hello world".to_string())
            .unwrap();
        assert_eq!(titled, "Hello World");
    }

    #[test]
    fn test_replace_pattern() {
        let swap = ReplacePattern::new("world".to_string(), "universe".to_string());
        let swapped = swap.transform("hello world".to_string()).unwrap();
        assert_eq!(swapped, "hello universe");

        let strip = ReplacePattern::remove("test ".to_string());
        let stripped = strip
            .transform("test hello test world".to_string())
            .unwrap();
        assert_eq!(stripped, "hello world");
    }

    #[test]
    fn test_convenience_functions() {
        // Smoke test: every convenience constructor can be called.
        let _whitespace_tokenizer = tokenize_whitespace();
        let _comma_tokenizer = tokenize(",");
        let _stopword_filter = remove_english_stopwords();
        let _stemmer = porter_stemmer();
        let _bigram_generator = bigrams();
        let _trigram_generator = trigrams();
        let _length_filter = filter_by_length(3, Some(10));
    }
}