1use crate::error::{Result, TextError};
64use crate::lemmatization::Lemmatizer;
65use crate::lemmatization::WordNetLemmatizer;
66use crate::stemming::{PorterStemmer, Stemmer};
67use scirs2_core::parallel_ops::*;
68use std::collections::HashSet;
69use std::sync::Arc;
70
/// Returns the built-in English stopword set used by
/// [`PipelineStep::RemoveStopwords`].
///
/// Covers common articles, pronouns, auxiliary verbs, prepositions,
/// and conjunctions. Matching is performed on the lowercased token, so
/// all entries here are lowercase.
fn default_stopwords() -> HashSet<&'static str> {
    // NOTE(review): the previous list contained duplicate entries
    // ("because", "both") — harmless in a set, removed here.
    [
        "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        "from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
        "does", "did", "will", "would", "shall", "should", "may", "might", "can", "could", "it",
        "its", "this", "that", "these", "those", "i", "me", "my", "myself", "we", "our", "ours",
        "ourselves", "you", "your", "yours", "yourself", "yourselves", "he", "him", "his",
        "himself", "she", "her", "hers", "herself", "they", "them", "their", "theirs",
        "themselves", "what", "which", "who", "whom", "when", "where", "why", "how", "all",
        "each", "every", "both", "few", "more", "most", "other", "some", "such", "no", "not",
        "only", "same", "so", "than", "too", "very", "just", "because", "as", "until", "while",
        "about", "against", "between", "into", "through", "during", "before", "after", "above",
        "below", "up", "down", "out", "off", "over", "under", "again", "further", "then", "once",
        "here", "there", "any", "also", "if", "though", "although", "since", "unless", "whether",
        "nor", "neither", "either", "like", "across", "among", "along", "around", "near",
        "within", "without", "toward", "towards", "via", "per", "upon", "onto", "beside",
        "besides", "behind",
    ]
    .iter()
    .copied()
    .collect()
}
233
/// A single stage in an NLP pipeline.
///
/// Steps are applied in order; each step consumes the token list
/// produced by the previous one.
#[derive(Clone)]
pub enum PipelineStep {
    /// Split each token on whitespace into words, stripping leading and
    /// trailing non-alphanumeric characters.
    Tokenize,

    /// Lowercase every token.
    Lowercase,

    /// Drop tokens whose lowercased form is in the built-in English
    /// stopword list.
    RemoveStopwords,

    /// Drop tokens containing no alphanumeric characters at all.
    RemovePunctuation,

    /// Reduce each token to its stem via the Porter stemmer.
    Stem,

    /// Reduce each token to its lemma via the WordNet lemmatizer
    /// (no part-of-speech hint).
    Lemmatize,

    /// Replace the token list with underscore-joined n-grams of the
    /// given size (must be >= 1).
    NGrams(usize),

    /// Apply an arbitrary user-supplied transformation to the token
    /// list. Wrapped in `Arc` so the step stays `Clone`.
    Custom(Arc<dyn Fn(Vec<String>) -> Vec<String> + Send + Sync>),
}
274
275impl std::fmt::Debug for PipelineStep {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 match self {
278 PipelineStep::Tokenize => write!(f, "Tokenize"),
279 PipelineStep::Lowercase => write!(f, "Lowercase"),
280 PipelineStep::RemoveStopwords => write!(f, "RemoveStopwords"),
281 PipelineStep::RemovePunctuation => write!(f, "RemovePunctuation"),
282 PipelineStep::Stem => write!(f, "Stem"),
283 PipelineStep::Lemmatize => write!(f, "Lemmatize"),
284 PipelineStep::NGrams(n) => write!(f, "NGrams({n})"),
285 PipelineStep::Custom(_) => write!(f, "Custom(..)"),
286 }
287 }
288}
289
/// A configurable text-processing pipeline that applies a sequence of
/// [`PipelineStep`]s to an input string.
pub struct NlpPipeline {
    // Steps applied in order by `process`.
    steps: Vec<PipelineStep>,
    // English stopword set consulted by the `RemoveStopwords` step.
    stopwords: HashSet<&'static str>,
    // Stemmer used by the `Stem` step.
    stemmer: PorterStemmer,
    // Lemmatizer used by the `Lemmatize` step.
    lemmatizer: WordNetLemmatizer,
}
304
impl std::fmt::Debug for NlpPipeline {
    // Manual impl showing only `steps`; the stemmer/lemmatizer fields are
    // omitted (presumably they do not implement `Debug` — confirm against
    // their definitions).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NlpPipeline")
            .field("steps", &self.steps)
            .finish()
    }
}
312
313impl NlpPipeline {
314 pub fn new(steps: Vec<PipelineStep>) -> Self {
318 Self {
319 steps,
320 stopwords: default_stopwords(),
321 stemmer: PorterStemmer::new(),
322 lemmatizer: WordNetLemmatizer::new(),
323 }
324 }
325
326 pub fn process(&self, text: &str) -> Result<Vec<String>> {
331 let mut tokens: Vec<String> = vec![text.to_string()];
334
335 for step in &self.steps {
336 tokens = self.apply_step(step, tokens)?;
337 }
338
339 Ok(tokens)
340 }
341
342 pub fn steps(&self) -> &[PipelineStep] {
344 &self.steps
345 }
346
347 fn apply_step(&self, step: &PipelineStep, tokens: Vec<String>) -> Result<Vec<String>> {
350 match step {
351 PipelineStep::Tokenize => self.step_tokenize(tokens),
352 PipelineStep::Lowercase => Ok(Self::step_lowercase(tokens)),
353 PipelineStep::RemoveStopwords => Ok(self.step_remove_stopwords(tokens)),
354 PipelineStep::RemovePunctuation => Ok(Self::step_remove_punctuation(tokens)),
355 PipelineStep::Stem => self.step_stem(tokens),
356 PipelineStep::Lemmatize => self.step_lemmatize(tokens),
357 PipelineStep::NGrams(n) => Self::step_ngrams(tokens, *n),
358 PipelineStep::Custom(f) => Ok(f(tokens)),
359 }
360 }
361
362 fn step_tokenize(&self, tokens: Vec<String>) -> Result<Vec<String>> {
365 let mut out = Vec::new();
366 for tok in tokens {
367 let words = extract_words(&tok);
369 out.extend(words);
370 }
371 Ok(out)
372 }
373
374 fn step_lowercase(tokens: Vec<String>) -> Vec<String> {
375 tokens.into_iter().map(|t| t.to_lowercase()).collect()
376 }
377
378 fn step_remove_stopwords(&self, tokens: Vec<String>) -> Vec<String> {
379 tokens
380 .into_iter()
381 .filter(|t| !self.stopwords.contains(t.to_lowercase().as_str()))
382 .collect()
383 }
384
385 fn step_remove_punctuation(tokens: Vec<String>) -> Vec<String> {
386 tokens
387 .into_iter()
388 .filter(|t| t.chars().any(|c| c.is_alphanumeric()))
389 .collect()
390 }
391
392 fn step_stem(&self, tokens: Vec<String>) -> Result<Vec<String>> {
393 tokens
394 .into_iter()
395 .map(|t| {
396 self.stemmer
397 .stem(&t)
398 .map_err(|e| TextError::ProcessingError(e.to_string()))
399 })
400 .collect()
401 }
402
403 fn step_lemmatize(&self, tokens: Vec<String>) -> Result<Vec<String>> {
404 tokens
405 .into_iter()
406 .map(|t| {
407 self.lemmatizer
408 .lemmatize(&t, None)
409 .map_err(|e| TextError::ProcessingError(e.to_string()))
410 })
411 .collect()
412 }
413
414 fn step_ngrams(tokens: Vec<String>, n: usize) -> Result<Vec<String>> {
415 if n == 0 {
416 return Err(TextError::InvalidInput("NGrams n must be >= 1".to_string()));
417 }
418 if n == 1 {
419 return Ok(tokens);
420 }
421 if tokens.len() < n {
422 return Ok(Vec::new());
423 }
424
425 let grams = tokens.windows(n).map(|window| window.join("_")).collect();
426
427 Ok(grams)
428 }
429}
430
/// Fluent builder for assembling an [`NlpPipeline`] step by step.
#[derive(Debug, Default)]
pub struct PipelineBuilder {
    // Steps accumulated via `add_step`, in application order.
    steps: Vec<PipelineStep>,
}
457
458impl PipelineBuilder {
459 pub fn new() -> Self {
461 Self { steps: Vec::new() }
462 }
463
464 pub fn add_step(mut self, step: PipelineStep) -> Self {
466 self.steps.push(step);
467 self
468 }
469
470 pub fn build(self) -> NlpPipeline {
472 NlpPipeline::new(self.steps)
473 }
474}
475
/// Applies an [`NlpPipeline`] to many documents, switching from
/// sequential to parallel iteration once the batch size reaches
/// `parallel_threshold`.
pub struct BatchProcessor {
    // Shared pipeline; `Arc` lets the parallel path hold a cheap handle.
    pipeline: Arc<NlpPipeline>,
    // Minimum document count at which `par_iter` is used (default 32).
    parallel_threshold: usize,
}
508
509impl BatchProcessor {
510 pub fn new(pipeline: NlpPipeline) -> Self {
512 Self {
513 pipeline: Arc::new(pipeline),
514 parallel_threshold: 32,
515 }
516 }
517
518 pub fn with_parallel_threshold(mut self, threshold: usize) -> Self {
521 self.parallel_threshold = threshold;
522 self
523 }
524
525 pub fn process_batch(&self, documents: &[&str]) -> Result<Vec<Vec<String>>> {
527 if documents.len() < self.parallel_threshold {
528 documents
530 .iter()
531 .map(|doc| self.pipeline.process(doc))
532 .collect()
533 } else {
534 let pipeline = Arc::clone(&self.pipeline);
536 let results: Vec<Result<Vec<String>>> = documents
537 .par_iter()
538 .map(|doc| pipeline.process(doc))
539 .collect();
540
541 results.into_iter().collect()
542 }
543 }
544
545 pub fn process_batch_tolerant(
548 &self,
549 documents: &[&str],
550 ) -> Vec<std::result::Result<Vec<String>, TextError>> {
551 if documents.len() < self.parallel_threshold {
552 documents
553 .iter()
554 .map(|doc| self.pipeline.process(doc))
555 .collect()
556 } else {
557 let pipeline = Arc::clone(&self.pipeline);
558 documents
559 .par_iter()
560 .map(|doc| pipeline.process(doc))
561 .collect()
562 }
563 }
564
565 pub fn pipeline(&self) -> &NlpPipeline {
567 &self.pipeline
568 }
569}
570
571pub fn basic_pipeline() -> NlpPipeline {
577 PipelineBuilder::new()
578 .add_step(PipelineStep::Tokenize)
579 .add_step(PipelineStep::Lowercase)
580 .add_step(PipelineStep::RemovePunctuation)
581 .add_step(PipelineStep::RemoveStopwords)
582 .build()
583}
584
585pub fn stemming_pipeline() -> NlpPipeline {
588 PipelineBuilder::new()
589 .add_step(PipelineStep::Tokenize)
590 .add_step(PipelineStep::Lowercase)
591 .add_step(PipelineStep::RemovePunctuation)
592 .add_step(PipelineStep::RemoveStopwords)
593 .add_step(PipelineStep::Stem)
594 .build()
595}
596
597pub fn lemmatization_pipeline() -> NlpPipeline {
600 PipelineBuilder::new()
601 .add_step(PipelineStep::Tokenize)
602 .add_step(PipelineStep::Lowercase)
603 .add_step(PipelineStep::RemovePunctuation)
604 .add_step(PipelineStep::RemoveStopwords)
605 .add_step(PipelineStep::Lemmatize)
606 .build()
607}
608
609pub fn ngram_pipeline(n: usize) -> NlpPipeline {
612 PipelineBuilder::new()
613 .add_step(PipelineStep::Tokenize)
614 .add_step(PipelineStep::Lowercase)
615 .add_step(PipelineStep::RemovePunctuation)
616 .add_step(PipelineStep::RemoveStopwords)
617 .add_step(PipelineStep::NGrams(n))
618 .build()
619}
620
/// Splits `text` on whitespace and strips leading/trailing
/// non-alphanumeric characters from each fragment, discarding fragments
/// that become empty (pure punctuation).
///
/// Interior punctuation is preserved, so "Foo-bar." yields "Foo-bar".
fn extract_words(text: &str) -> Vec<String> {
    text.split_whitespace()
        // Trim on the borrowed slice and filter BEFORE allocating, so
        // discarded punctuation-only fragments never cost a `String`
        // (the previous version allocated for every fragment).
        .map(|raw| raw.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|word| !word.is_empty())
        .map(str::to_string)
        .collect()
}
641
#[cfg(test)]
mod tests {
    use super::*;

    // Builder should record exactly the steps it was given.
    #[test]
    fn test_builder_creates_pipeline() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .build();
        assert_eq!(pipeline.steps().len(), 2);
    }

    // Tokenize splits on whitespace.
    #[test]
    fn test_tokenize_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .build();
        let tokens = pipeline.process("hello world foo").unwrap();
        assert_eq!(tokens, vec!["hello", "world", "foo"]);
    }

    // Lowercase normalizes mixed-case input.
    #[test]
    fn test_lowercase_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .build();
        let tokens = pipeline.process("Hello World FOO").unwrap();
        assert_eq!(tokens, vec!["hello", "world", "foo"]);
    }

    // After RemovePunctuation every surviving token has at least one
    // alphanumeric character.
    #[test]
    fn test_remove_punctuation_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::RemovePunctuation)
            .build();
        let tokens = pipeline.process("Hello, world! This is a test.").unwrap();
        assert!(tokens
            .iter()
            .all(|t| t.chars().any(|c| c.is_alphanumeric())));
    }

    // Stopwords are dropped; content words survive.
    #[test]
    fn test_remove_stopwords_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .add_step(PipelineStep::RemoveStopwords)
            .build();
        let tokens = pipeline
            .process("the quick brown fox is a fast animal")
            .unwrap();
        assert!(!tokens.contains(&"the".to_string()));
        assert!(!tokens.contains(&"is".to_string()));
        assert!(!tokens.contains(&"a".to_string()));
        assert!(tokens.contains(&"quick".to_string()));
        assert!(tokens.contains(&"brown".to_string()));
        assert!(tokens.contains(&"fox".to_string()));
    }

    // Porter stemming reduces inflected forms to stems.
    #[test]
    fn test_stem_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .add_step(PipelineStep::Stem)
            .build();
        let tokens = pipeline.process("running dogs are jumping").unwrap();
        assert!(tokens.contains(&"run".to_string()), "tokens: {tokens:?}");
        assert!(tokens.contains(&"dog".to_string()), "tokens: {tokens:?}");
        assert!(tokens.contains(&"jump".to_string()), "tokens: {tokens:?}");
    }

    // Lemmatization may or may not reduce "cats" depending on the
    // lemmatizer's dictionary, so either form is accepted.
    #[test]
    fn test_lemmatize_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .add_step(PipelineStep::Lemmatize)
            .build();
        let tokens = pipeline.process("The cats went to the mice").unwrap();
        assert!(
            tokens.contains(&"cat".to_string()) || tokens.contains(&"cats".to_string()),
            "tokens: {tokens:?}"
        );
    }

    // Bigrams are joined with underscores, in order.
    #[test]
    fn test_ngrams_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .add_step(PipelineStep::NGrams(2))
            .build();
        let tokens = pipeline.process("quick brown fox").unwrap();
        assert_eq!(tokens, vec!["quick_brown", "brown_fox"]);
    }

    // Trigrams over four tokens yield two windows.
    #[test]
    fn test_ngrams_step_trigram() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::NGrams(3))
            .build();
        let tokens = pipeline.process("a b c d").unwrap();
        assert_eq!(tokens, vec!["a_b_c", "b_c_d"]);
    }

    // NGrams(0) is rejected with an error at process time.
    #[test]
    fn test_ngrams_invalid_n() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::NGrams(0))
            .build();
        let result = pipeline.process("hello world");
        assert!(result.is_err());
    }

    // Fewer tokens than n yields an empty result rather than an error.
    #[test]
    fn test_ngrams_too_short() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::NGrams(5))
            .build();
        let tokens = pipeline.process("hi").unwrap();
        assert!(tokens.is_empty());
    }

    // A Custom step can apply any closure over the token list.
    #[test]
    fn test_custom_step() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Custom(Arc::new(|tokens| {
                tokens.into_iter().filter(|t| t.len() > 3).collect()
            })))
            .build();
        let tokens = pipeline.process("I am the quick brown fox").unwrap();
        assert!(tokens.iter().all(|t| t.len() > 3));
    }

    // Empty input produces no tokens.
    #[test]
    fn test_empty_input() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .build();
        let tokens = pipeline.process("").unwrap();
        assert!(tokens.is_empty());
    }

    // End-to-end: tokenize, lowercase, punctuation, stopwords, stem.
    #[test]
    fn test_full_pipeline() {
        let pipeline = PipelineBuilder::new()
            .add_step(PipelineStep::Tokenize)
            .add_step(PipelineStep::Lowercase)
            .add_step(PipelineStep::RemovePunctuation)
            .add_step(PipelineStep::RemoveStopwords)
            .add_step(PipelineStep::Stem)
            .build();
        let tokens = pipeline
            .process("The quick brown foxes are jumping over the lazy dogs!")
            .unwrap();
        assert!(!tokens.contains(&"the".to_string()));
        assert!(!tokens.contains(&"are".to_string()));
        assert!(tokens.iter().any(|t| t == "fox" || t.starts_with("fox")));
        assert!(!tokens.is_empty());
    }

    // Small batch (below default threshold) takes the sequential path.
    #[test]
    fn test_batch_processor_basic() {
        let pipeline = basic_pipeline();
        let processor = BatchProcessor::new(pipeline);
        let docs = vec![
            "The quick brown fox",
            "Hello world this is a test",
            "Machine learning is fascinating",
        ];
        let results = processor.process_batch(&docs).unwrap();
        assert_eq!(results.len(), 3);
        for (doc, tokens) in docs.iter().zip(results.iter()) {
            assert!(!tokens.is_empty(), "expected tokens for doc: {doc}");
        }
    }

    // Threshold 0 forces the parallel path for every batch size.
    #[test]
    fn test_batch_processor_parallel() {
        let pipeline = stemming_pipeline();
        let processor = BatchProcessor::new(pipeline).with_parallel_threshold(0);

        let docs: Vec<&str> = (0..100)
            .map(|_| "running foxes jumping over lazy dogs")
            .collect();

        let results = processor.process_batch(&docs).unwrap();
        assert_eq!(results.len(), 100);
        for tokens in &results {
            assert!(
                tokens.iter().any(|t| t == "fox"),
                "expected 'fox' in {tokens:?}"
            );
        }
    }

    // Tolerant variant returns one Result per document.
    #[test]
    fn test_batch_processor_tolerant() {
        let pipeline = basic_pipeline();
        let processor = BatchProcessor::new(pipeline);
        let docs = vec!["hello world", "the quick brown fox"];
        let results = processor.process_batch_tolerant(&docs);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // An empty document yields an empty (not missing) token list.
    #[test]
    fn test_batch_processor_empty_doc() {
        let pipeline = basic_pipeline();
        let processor = BatchProcessor::new(pipeline);
        let docs = vec!["", "hello world"];
        let results = processor.process_batch(&docs).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].is_empty());
        assert!(!results[1].is_empty());
    }

    // Factory functions: basic pipeline removes stopwords.
    #[test]
    fn test_basic_pipeline_factory() {
        let pipeline = basic_pipeline();
        let tokens = pipeline.process("The fox is quick and agile").unwrap();
        assert!(!tokens.contains(&"the".to_string()));
        assert!(!tokens.contains(&"is".to_string()));
    }

    // Stemming pipeline also stems surviving tokens.
    #[test]
    fn test_stemming_pipeline_factory() {
        let pipeline = stemming_pipeline();
        let tokens = pipeline.process("The dogs are running fast").unwrap();
        assert!(!tokens.contains(&"the".to_string()));
        assert!(tokens.contains(&"dog".to_string()));
        assert!(tokens.contains(&"run".to_string()));
    }

    // Lemmatization pipeline maps irregular forms to lemmas.
    #[test]
    fn test_lemmatization_pipeline_factory() {
        let pipeline = lemmatization_pipeline();
        let tokens = pipeline.process("The mice went to sleep").unwrap();
        assert!(!tokens.contains(&"the".to_string()));
        assert!(tokens.contains(&"mouse".to_string()));
        assert!(tokens.contains(&"go".to_string()));
    }

    // N-gram pipeline output tokens are underscore-joined.
    #[test]
    fn test_ngram_pipeline_factory() {
        let pipeline = ngram_pipeline(2);
        let tokens = pipeline.process("quick brown fox").unwrap();
        for tok in &tokens {
            assert!(tok.contains('_'), "expected bigram, got: {tok}");
        }
    }

    // extract_words trims surrounding punctuation but keeps interior
    // punctuation (e.g. the hyphen in "Foo-bar").
    #[test]
    fn test_extract_words_strips_punctuation() {
        let words = extract_words("Hello, world! Foo-bar.");
        assert!(words.contains(&"Hello".to_string()), "{words:?}");
        assert!(words.contains(&"world".to_string()), "{words:?}");
        assert!(words.contains(&"Foo-bar".to_string()), "{words:?}");
    }

    // Whitespace-only input produces no words.
    #[test]
    fn test_extract_words_empty() {
        let words = extract_words("   ");
        assert!(words.is_empty());
    }
}
942}