use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Fit, Transform},
};
use std::collections::{HashMap, HashSet};

/// Strategy used to normalize raw text before tokenization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NormalizationStrategy {
    /// Leave the text unchanged.
    None,
    /// Convert the text to lowercase.
    Lowercase,
    /// Lowercase the text and replace punctuation with spaces.
    LowercaseNoPunct,
    /// Lowercase, replace punctuation with spaces, and collapse repeated whitespace.
    Full,
}

/// Strategy used to split normalized text into tokens.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TokenizationStrategy {
    /// Split on whitespace only.
    Whitespace,
    /// Split on whitespace and ASCII punctuation.
    WhitespacePunct,
    /// Split on any non-alphanumeric character.
    Word,
}

/// Kind of n-grams to generate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NgramType {
    /// Character-level n-grams.
    Char,
    /// Word-level n-grams.
    Word,
}

/// Configuration for [`TextTokenizer`].
#[derive(Debug, Clone)]
pub struct TextTokenizerConfig {
    /// Normalization applied before tokenization.
    pub normalization: NormalizationStrategy,
    /// Strategy used to split the normalized text into tokens.
    pub tokenization: TokenizationStrategy,
    /// Minimum token length (in bytes); shorter tokens are dropped.
    pub min_token_length: usize,
    /// Maximum token length (in bytes); longer tokens are dropped.
    pub max_token_length: usize,
    /// Optional set of stop words to remove after tokenization.
    pub stop_words: Option<HashSet<String>>,
}

impl Default for TextTokenizerConfig {
    fn default() -> Self {
        Self {
            normalization: NormalizationStrategy::Lowercase,
            tokenization: TokenizationStrategy::Word,
            min_token_length: 1,
            max_token_length: 50,
            stop_words: None,
        }
    }
}

/// Text tokenizer that normalizes input and splits it into tokens.
#[derive(Debug, Clone)]
pub struct TextTokenizer {
    config: TextTokenizerConfig,
}

impl Default for TextTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl TextTokenizer {
    /// Creates a tokenizer with the default configuration.
    pub fn new() -> Self {
        Self {
            config: TextTokenizerConfig::default(),
        }
    }

    /// Creates a tokenizer with the given configuration.
    pub fn with_config(config: TextTokenizerConfig) -> Self {
        Self { config }
    }

    /// Normalizes text according to the configured [`NormalizationStrategy`].
    pub fn normalize(&self, text: &str) -> String {
        match self.config.normalization {
            NormalizationStrategy::None => text.to_string(),
            NormalizationStrategy::Lowercase => text.to_lowercase(),
            // Lowercase and replace every non-alphanumeric, non-whitespace
            // character with a space.
            NormalizationStrategy::LowercaseNoPunct => text
                .to_lowercase()
                .chars()
                .map(|c| {
                    if c.is_alphanumeric() || c.is_whitespace() {
                        c
                    } else {
                        ' '
                    }
                })
                .collect(),
            // Same as `LowercaseNoPunct`, but also collapse runs of whitespace
            // into single spaces.
            NormalizationStrategy::Full => text
                .to_lowercase()
                .chars()
                .map(|c| {
                    if c.is_alphanumeric() || c.is_whitespace() {
                        c
                    } else {
                        ' '
                    }
                })
                .collect::<String>()
                .split_whitespace()
                .collect::<Vec<_>>()
                .join(" "),
        }
    }

    /// Normalizes and tokenizes text, then applies length and stop-word filters.
    pub fn tokenize(&self, text: &str) -> Vec<String> {
        let normalized = self.normalize(text);

        let tokens: Vec<String> = match self.config.tokenization {
            TokenizationStrategy::Whitespace => normalized
                .split_whitespace()
                .map(|s| s.to_string())
                .collect(),
            TokenizationStrategy::WhitespacePunct => normalized
                .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
                .filter(|s| !s.is_empty())
                .map(|s| s.to_string())
                .collect(),
            TokenizationStrategy::Word => normalized
                .split(|c: char| !c.is_alphanumeric())
                .filter(|s| !s.is_empty())
                .map(|s| s.to_string())
                .collect(),
        };

        // Drop tokens outside the configured byte-length bounds.
        let mut filtered_tokens: Vec<String> = tokens
            .into_iter()
            .filter(|token| {
                token.len() >= self.config.min_token_length
                    && token.len() <= self.config.max_token_length
            })
            .collect();

        // Remove stop words, if a stop-word set was configured.
        if let Some(ref stop_words) = self.config.stop_words {
            filtered_tokens.retain(|token| !stop_words.contains(token));
        }

        filtered_tokens
    }
}

/// Configuration for [`TfIdfVectorizer`].
#[derive(Debug, Clone)]
pub struct TfIdfVectorizerConfig {
    /// Tokenizer configuration used when building the vocabulary.
    pub tokenizer_config: TextTokenizerConfig,
    /// Minimum document frequency: values below 1.0 are interpreted as a
    /// proportion of documents, values of 1.0 or more as an absolute count.
    pub min_df: f64,
    /// Maximum document frequency: values up to and including 1.0 are
    /// interpreted as a proportion of documents, larger values as an
    /// absolute count.
    pub max_df: f64,
    /// Optional cap on the vocabulary size (most frequent terms are kept).
    pub max_features: Option<usize>,
    /// Whether to weight term frequencies by inverse document frequency.
    pub use_idf: bool,
    /// Whether to smooth IDF by adding one to document counts.
    pub smooth_idf: bool,
    /// Whether to use sublinear (logarithmic) term-frequency scaling.
    pub sublinear_tf: bool,
}

impl Default for TfIdfVectorizerConfig {
    fn default() -> Self {
        Self {
            tokenizer_config: TextTokenizerConfig::default(),
            min_df: 1.0,
            max_df: 1.0,
            max_features: None,
            use_idf: true,
            smooth_idf: true,
            sublinear_tf: false,
        }
    }
}

/// TF-IDF vectorizer that maps documents to term-weighted feature vectors.
#[derive(Debug, Clone)]
pub struct TfIdfVectorizer {
    config: TfIdfVectorizerConfig,
    tokenizer: TextTokenizer,
    /// Term-to-column mapping learned during fitting.
    vocabulary: HashMap<String, usize>,
    /// Per-term IDF weights, indexed by vocabulary column.
    idf_values: Array1<f64>,
    fitted: bool,
}

impl Default for TfIdfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

impl TfIdfVectorizer {
    /// Creates a vectorizer with the default configuration.
    pub fn new() -> Self {
        let config = TfIdfVectorizerConfig::default();
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());

        Self {
            config,
            tokenizer,
            vocabulary: HashMap::new(),
            idf_values: Array1::zeros(0),
            fitted: false,
        }
    }

    /// Creates a vectorizer with the given configuration.
    pub fn with_config(config: TfIdfVectorizerConfig) -> Self {
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());

        Self {
            config,
            tokenizer,
            vocabulary: HashMap::new(),
            idf_values: Array1::zeros(0),
            fitted: false,
        }
    }

    /// Builds the vocabulary and IDF weights from the training documents.
    fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        // Count in how many documents each term appears.
        let mut term_doc_counts: HashMap<String, usize> = HashMap::new();
        let n_docs = documents.len() as f64;

        for document in documents {
            let tokens = self.tokenizer.tokenize(document);
            let unique_tokens: HashSet<String> = tokens.into_iter().collect();

            for token in unique_tokens {
                *term_doc_counts.entry(token).or_insert(0) += 1;
            }
        }

        // Resolve min_df/max_df into absolute document counts. Values below
        // 1.0 (and, for max_df, exactly 1.0) are treated as proportions.
        let min_df = if self.config.min_df < 1.0 {
            (self.config.min_df * n_docs).ceil() as usize
        } else {
            self.config.min_df as usize
        };

        let max_df = if self.config.max_df <= 1.0 {
            (self.config.max_df * n_docs).floor() as usize
        } else {
            self.config.max_df as usize
        };

        let mut filtered_terms: Vec<(String, usize)> = term_doc_counts
            .into_iter()
            .filter(|(_, count)| *count >= min_df && *count <= max_df)
            .collect();

        // Sort by descending document frequency, breaking ties alphabetically,
        // so that truncation keeps the most frequent terms.
        filtered_terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        if let Some(max_features) = self.config.max_features {
            filtered_terms.truncate(max_features);
        }

        for (idx, (term, _doc_freq)) in filtered_terms.iter().enumerate() {
            self.vocabulary.insert(term.clone(), idx);
        }

        let vocab_size = self.vocabulary.len();
        let mut idf_values = Array1::zeros(vocab_size);

        if self.config.use_idf {
            // idf(t) = ln((n + 1) / (df(t) + 1)) + 1 with smoothing, or
            // ln(n / df(t)) + 1 without.
            for &idx in self.vocabulary.values() {
                let doc_freq = filtered_terms[idx].1 as f64;
                let idf = if self.config.smooth_idf {
                    ((n_docs + 1.0) / (doc_freq + 1.0)).ln() + 1.0
                } else {
                    (n_docs / doc_freq).ln() + 1.0
                };
                idf_values[idx] = idf;
            }
        } else {
            idf_values.fill(1.0);
        }

        self.idf_values = idf_values;
        Ok(())
    }

    /// Transforms documents into a dense TF-IDF matrix of shape
    /// `(n_documents, vocabulary_size)`.
    fn transform_documents(&self, documents: &[String]) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "TfIdfVectorizer not fitted".to_string(),
            });
        }

        let n_docs = documents.len();
        let vocab_size = self.vocabulary.len();
        let mut tfidf_matrix = Array2::zeros((n_docs, vocab_size));

        for (doc_idx, document) in documents.iter().enumerate() {
            let tokens = self.tokenizer.tokenize(document);
            let mut term_counts: HashMap<usize, f64> = HashMap::new();

            // Count occurrences of in-vocabulary terms only.
            for token in &tokens {
                if let Some(&vocab_idx) = self.vocabulary.get(token) {
                    *term_counts.entry(vocab_idx).or_insert(0.0) += 1.0;
                }
            }

            let total_terms = tokens.len() as f64;
            for (vocab_idx, count) in term_counts {
                // Sublinear scaling uses 1 + ln(count); otherwise the raw
                // count is normalized by the document length.
                let tf = if self.config.sublinear_tf {
                    1.0 + count.ln()
                } else {
                    count / total_terms
                };

                let tfidf = tf * self.idf_values[vocab_idx];
                tfidf_matrix[[doc_idx, vocab_idx]] = tfidf;
            }
        }

        Ok(tfidf_matrix)
    }

    /// Returns the learned term-to-column vocabulary mapping.
    pub fn get_vocabulary(&self) -> &HashMap<String, usize> {
        &self.vocabulary
    }

    /// Returns the learned IDF weights.
    pub fn get_idf_values(&self) -> ArrayView1<'_, f64> {
        self.idf_values.view()
    }
}

impl Fit<Vec<String>, ()> for TfIdfVectorizer {
    type Fitted = TfIdfVectorizer;

    fn fit(mut self, x: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        self.build_vocabulary(x)?;
        self.fitted = true;
        Ok(self)
    }
}

impl Transform<Vec<String>, Array2<f64>> for TfIdfVectorizer {
    fn transform(&self, x: &Vec<String>) -> Result<Array2<f64>> {
        self.transform_documents(x)
    }
}

/// Configuration for [`NgramGenerator`].
#[derive(Debug, Clone)]
pub struct NgramGeneratorConfig {
    /// Tokenizer configuration used for word n-grams and normalization.
    pub tokenizer_config: TextTokenizerConfig,
    /// Whether to generate character or word n-grams.
    pub ngram_type: NgramType,
    /// Smallest n-gram size to generate (inclusive).
    pub n_min: usize,
    /// Largest n-gram size to generate (inclusive).
    pub n_max: usize,
}

impl Default for NgramGeneratorConfig {
    fn default() -> Self {
        Self {
            tokenizer_config: TextTokenizerConfig::default(),
            ngram_type: NgramType::Word,
            n_min: 1,
            n_max: 2,
        }
    }
}

/// Generator for character or word n-grams.
#[derive(Debug, Clone)]
pub struct NgramGenerator {
    config: NgramGeneratorConfig,
    tokenizer: TextTokenizer,
}

impl Default for NgramGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl NgramGenerator {
    /// Creates a generator with the default configuration.
    pub fn new() -> Self {
        let config = NgramGeneratorConfig::default();
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());

        Self { config, tokenizer }
    }

    /// Creates a generator with the given configuration.
    pub fn with_config(config: NgramGeneratorConfig) -> Self {
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
        Self { config, tokenizer }
    }

    /// Generates all n-grams of sizes `n_min..=n_max` from the text.
    pub fn generate_ngrams(&self, text: &str) -> Vec<String> {
        match self.config.ngram_type {
            NgramType::Word => self.generate_word_ngrams(text),
            NgramType::Char => self.generate_char_ngrams(text),
        }
    }

    fn generate_word_ngrams(&self, text: &str) -> Vec<String> {
        let tokens = self.tokenizer.tokenize(text);
        let mut ngrams = Vec::new();

        for n in self.config.n_min..=self.config.n_max {
            // No n-grams of size n exist once n exceeds the token count.
            if n > tokens.len() {
                break;
            }

            // Slide a window of n tokens across the sequence, joining each
            // window with spaces.
            for window in tokens.windows(n) {
                let ngram = window.join(" ");
                ngrams.push(ngram);
            }
        }

        ngrams
    }

    fn generate_char_ngrams(&self, text: &str) -> Vec<String> {
        // Character n-grams are built from the normalized text, not from tokens.
        let normalized = self.tokenizer.normalize(text);
        let chars: Vec<char> = normalized.chars().collect();
        let mut ngrams = Vec::new();

        for n in self.config.n_min..=self.config.n_max {
            if n > chars.len() {
                break;
            }

            for window in chars.windows(n) {
                let ngram: String = window.iter().collect();
                ngrams.push(ngram);
            }
        }

        ngrams
    }
}

/// Configuration for [`TextSimilarity`].
#[derive(Debug, Clone)]
pub struct TextSimilarityConfig {
    /// Tokenizer configuration used for both input texts.
    pub tokenizer_config: TextTokenizerConfig,
    /// Similarity metric to compute.
    pub similarity_metric: SimilarityMetric,
}

/// Supported text similarity metrics.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SimilarityMetric {
    /// Cosine similarity of term-frequency vectors.
    Cosine,
    /// Jaccard similarity of token sets.
    Jaccard,
    /// Dice coefficient of token sets.
    Dice,
}

impl Default for TextSimilarityConfig {
    fn default() -> Self {
        Self {
            tokenizer_config: TextTokenizerConfig::default(),
            similarity_metric: SimilarityMetric::Cosine,
        }
    }
}

/// Computes pairwise similarity between two texts.
#[derive(Debug, Clone)]
pub struct TextSimilarity {
    config: TextSimilarityConfig,
    tokenizer: TextTokenizer,
}

impl Default for TextSimilarity {
    fn default() -> Self {
        Self::new()
    }
}

impl TextSimilarity {
    /// Creates a similarity calculator with the default configuration.
    pub fn new() -> Self {
        let config = TextSimilarityConfig::default();
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());

        Self { config, tokenizer }
    }

    /// Creates a similarity calculator with the given configuration.
    pub fn with_config(config: TextSimilarityConfig) -> Self {
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());
        Self { config, tokenizer }
    }

    /// Computes the configured similarity metric between two texts.
    pub fn similarity(&self, text1: &str, text2: &str) -> f64 {
        match self.config.similarity_metric {
            SimilarityMetric::Cosine => self.cosine_similarity(text1, text2),
            SimilarityMetric::Jaccard => self.jaccard_similarity(text1, text2),
            SimilarityMetric::Dice => self.dice_coefficient(text1, text2),
        }
    }

    fn cosine_similarity(&self, text1: &str, text2: &str) -> f64 {
        let tokens1 = self.tokenizer.tokenize(text1);
        let tokens2 = self.tokenizer.tokenize(text2);

        // Build term-frequency vectors for both texts.
        let mut term_freq1: HashMap<String, f64> = HashMap::new();
        let mut term_freq2: HashMap<String, f64> = HashMap::new();

        for token in tokens1 {
            *term_freq1.entry(token).or_insert(0.0) += 1.0;
        }

        for token in tokens2 {
            *term_freq2.entry(token).or_insert(0.0) += 1.0;
        }

        let mut dot_product = 0.0;
        let mut norm1 = 0.0;
        let mut norm2 = 0.0;

        let all_terms: HashSet<String> = term_freq1
            .keys()
            .chain(term_freq2.keys())
            .cloned()
            .collect();

        // cosine(a, b) = (a . b) / (|a| * |b|)
        for term in all_terms {
            let freq1 = term_freq1.get(&term).unwrap_or(&0.0);
            let freq2 = term_freq2.get(&term).unwrap_or(&0.0);

            dot_product += freq1 * freq2;
            norm1 += freq1 * freq1;
            norm2 += freq2 * freq2;
        }

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1.sqrt() * norm2.sqrt())
        }
    }

    fn jaccard_similarity(&self, text1: &str, text2: &str) -> f64 {
        let tokens1: HashSet<String> = self.tokenizer.tokenize(text1).into_iter().collect();
        let tokens2: HashSet<String> = self.tokenizer.tokenize(text2).into_iter().collect();

        // jaccard(A, B) = |A ∩ B| / |A ∪ B|
        let intersection = tokens1.intersection(&tokens2).count();
        let union = tokens1.union(&tokens2).count();

        if union == 0 {
            0.0
        } else {
            intersection as f64 / union as f64
        }
    }

    fn dice_coefficient(&self, text1: &str, text2: &str) -> f64 {
        let tokens1: HashSet<String> = self.tokenizer.tokenize(text1).into_iter().collect();
        let tokens2: HashSet<String> = self.tokenizer.tokenize(text2).into_iter().collect();

        // dice(A, B) = 2 |A ∩ B| / (|A| + |B|)
        let intersection = tokens1.intersection(&tokens2).count();
        let total = tokens1.len() + tokens2.len();

        if total == 0 {
            0.0
        } else {
            2.0 * intersection as f64 / total as f64
        }
    }
}

/// Configuration for [`BagOfWordsEmbedding`].
#[derive(Debug, Clone, Default)]
pub struct BagOfWordsConfig {
    /// Tokenizer configuration used when building the vocabulary.
    pub tokenizer_config: TextTokenizerConfig,
    /// Optional cap on the vocabulary size (most frequent terms are kept).
    pub max_features: Option<usize>,
    /// If true, emit 1.0 for present terms instead of raw counts.
    pub binary: bool,
}

/// Bag-of-words embedding that maps documents to term-count vectors.
#[derive(Debug, Clone)]
pub struct BagOfWordsEmbedding {
    config: BagOfWordsConfig,
    tokenizer: TextTokenizer,
    vocabulary: HashMap<String, usize>,
    fitted: bool,
}

impl Default for BagOfWordsEmbedding {
    fn default() -> Self {
        Self::new()
    }
}

impl BagOfWordsEmbedding {
    /// Creates an embedding with the default configuration.
    pub fn new() -> Self {
        let config = BagOfWordsConfig::default();
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());

        Self {
            config,
            tokenizer,
            vocabulary: HashMap::new(),
            fitted: false,
        }
    }

    /// Creates an embedding with the given configuration.
    pub fn with_config(config: BagOfWordsConfig) -> Self {
        let tokenizer = TextTokenizer::with_config(config.tokenizer_config.clone());

        Self {
            config,
            tokenizer,
            vocabulary: HashMap::new(),
            fitted: false,
        }
    }

    /// Builds the vocabulary from total term counts across all documents.
    fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
        let mut term_counts: HashMap<String, usize> = HashMap::new();

        for document in documents {
            let tokens = self.tokenizer.tokenize(document);
            for token in tokens {
                *term_counts.entry(token).or_insert(0) += 1;
            }
        }

        // Sort by descending count, breaking ties alphabetically, so that
        // truncation keeps the most frequent terms.
        let mut sorted_terms: Vec<(String, usize)> = term_counts.into_iter().collect();
        sorted_terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        if let Some(max_features) = self.config.max_features {
            sorted_terms.truncate(max_features);
        }

        for (idx, (term, _)) in sorted_terms.iter().enumerate() {
            self.vocabulary.insert(term.clone(), idx);
        }

        Ok(())
    }

    /// Transforms documents into a dense count (or binary) matrix of shape
    /// `(n_documents, vocabulary_size)`.
    fn transform_documents(&self, documents: &[String]) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(SklearsError::NotFitted {
                operation: "BagOfWordsEmbedding not fitted".to_string(),
            });
        }

        let n_docs = documents.len();
        let vocab_size = self.vocabulary.len();
        let mut bow_matrix = Array2::zeros((n_docs, vocab_size));

        for (doc_idx, document) in documents.iter().enumerate() {
            let tokens = self.tokenizer.tokenize(document);
            let mut term_counts: HashMap<usize, f64> = HashMap::new();

            // Count occurrences of in-vocabulary terms only.
            for token in &tokens {
                if let Some(&vocab_idx) = self.vocabulary.get(token) {
                    *term_counts.entry(vocab_idx).or_insert(0.0) += 1.0;
                }
            }

            for (vocab_idx, count) in term_counts {
                let value = if self.config.binary { 1.0 } else { count };
                bow_matrix[[doc_idx, vocab_idx]] = value;
            }
        }

        Ok(bow_matrix)
    }

    /// Returns the learned term-to-column vocabulary mapping.
    pub fn get_vocabulary(&self) -> &HashMap<String, usize> {
        &self.vocabulary
    }
}

impl Fit<Vec<String>, ()> for BagOfWordsEmbedding {
    type Fitted = BagOfWordsEmbedding;

    fn fit(mut self, x: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        self.build_vocabulary(x)?;
        self.fitted = true;
        Ok(self)
    }
}

impl Transform<Vec<String>, Array2<f64>> for BagOfWordsEmbedding {
    fn transform(&self, x: &Vec<String>) -> Result<Array2<f64>> {
        self.transform_documents(x)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_text_tokenizer() {
        let tokenizer = TextTokenizer::new();
        let text = "Hello, World! This is a TEST.";
        let tokens = tokenizer.tokenize(text);

        assert_eq!(tokens, vec!["hello", "world", "this", "is", "a", "test"]);
    }

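    // Illustrative addition: a minimal sketch of stop-word filtering through
    // TextTokenizerConfig, using only fields defined above.
    #[test]
    fn test_text_tokenizer_stop_words() {
        let stop_words: std::collections::HashSet<String> =
            ["the", "is"].iter().map(|s| s.to_string()).collect();
        let config = TextTokenizerConfig {
            stop_words: Some(stop_words),
            ..TextTokenizerConfig::default()
        };
        let tokenizer = TextTokenizer::with_config(config);

        let tokens = tokenizer.tokenize("The cat is on the mat");
        assert_eq!(tokens, vec!["cat", "on", "mat"]);
    }
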
    #[test]
    fn test_tfidf_vectorizer() {
        let vectorizer = TfIdfVectorizer::new();
        let documents = vec![
            "the cat sat on the mat".to_string(),
            "the dog ran in the park".to_string(),
            "cats and dogs are pets".to_string(),
        ];

        let fitted_vectorizer = vectorizer.fit(&documents, &()).unwrap();
        let tfidf_matrix = fitted_vectorizer.transform(&documents).unwrap();

        assert_eq!(
            tfidf_matrix.shape(),
            &[3, fitted_vectorizer.vocabulary.len()]
        );

        for &value in tfidf_matrix.iter() {
            assert!(value >= 0.0);
        }
    }

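    // Illustrative addition: transforming with an unfitted vectorizer should
    // surface the NotFitted error path shown above.
    #[test]
    fn test_tfidf_vectorizer_not_fitted() {
        let vectorizer = TfIdfVectorizer::new();
        let documents = vec!["the cat sat".to_string()];

        assert!(vectorizer.transform(&documents).is_err());
    }
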
    #[test]
    fn test_ngram_generator() {
        let generator = NgramGenerator::new();
        let text = "the quick brown fox";
        let ngrams = generator.generate_ngrams(text);

        assert!(ngrams.contains(&"the".to_string()));
        assert!(ngrams.contains(&"quick".to_string()));
        assert!(ngrams.contains(&"the quick".to_string()));
        assert!(ngrams.contains(&"quick brown".to_string()));
    }

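    // Illustrative addition: character n-grams of a fixed size, configured via
    // NgramGeneratorConfig with the fields defined above.
    #[test]
    fn test_char_ngram_generator() {
        let config = NgramGeneratorConfig {
            tokenizer_config: TextTokenizerConfig::default(),
            ngram_type: NgramType::Char,
            n_min: 2,
            n_max: 2,
        };
        let generator = NgramGenerator::with_config(config);

        // "ab cd" has five characters, so there are four character bigrams.
        let ngrams = generator.generate_ngrams("ab cd");
        assert_eq!(ngrams.len(), 4);
        assert!(ngrams.contains(&"ab".to_string()));
        assert!(ngrams.contains(&"cd".to_string()));
    }
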
    #[test]
    fn test_text_similarity() {
        let similarity = TextSimilarity::new();

        let sim1 = similarity.similarity("the cat sat", "the cat sat");
        assert_abs_diff_eq!(sim1, 1.0, epsilon = 1e-10);

        let sim2 = similarity.similarity("the cat sat", "the dog ran");
        assert!(sim2 > 0.0 && sim2 < 1.0);

        let sim3 = similarity.similarity("hello world", "goodbye moon");
        assert_eq!(sim3, 0.0);
    }

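    // Illustrative addition: Jaccard and Dice metrics on overlapping token sets.
    // With "the cat sat" vs "the cat ran": |A ∩ B| = 2, |A ∪ B| = 4, |A| = |B| = 3,
    // so Jaccard = 0.5 and Dice = 2/3.
    #[test]
    fn test_text_similarity_set_metrics() {
        let jaccard = TextSimilarity::with_config(TextSimilarityConfig {
            similarity_metric: SimilarityMetric::Jaccard,
            ..TextSimilarityConfig::default()
        });
        let dice = TextSimilarity::with_config(TextSimilarityConfig {
            similarity_metric: SimilarityMetric::Dice,
            ..TextSimilarityConfig::default()
        });

        assert_abs_diff_eq!(
            jaccard.similarity("the cat sat", "the cat ran"),
            0.5,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(
            dice.similarity("the cat sat", "the cat ran"),
            2.0 / 3.0,
            epsilon = 1e-10
        );
    }
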
    #[test]
    fn test_bag_of_words_embedding() {
        let embedding = BagOfWordsEmbedding::new();
        let documents = vec![
            "the cat sat".to_string(),
            "the dog ran".to_string(),
            "cats and dogs".to_string(),
        ];

        let fitted_embedding = embedding.fit(&documents, &()).unwrap();
        let bow_matrix = fitted_embedding.transform(&documents).unwrap();

        assert_eq!(bow_matrix.shape(), &[3, fitted_embedding.vocabulary.len()]);

        for &value in bow_matrix.iter() {
            assert!(value >= 0.0);
        }
    }

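    // Illustrative addition: with `binary: true`, repeated terms in a document
    // still produce 1.0 entries rather than raw counts.
    #[test]
    fn test_bag_of_words_binary() {
        let config = BagOfWordsConfig {
            tokenizer_config: TextTokenizerConfig::default(),
            max_features: None,
            binary: true,
        };
        let documents = vec!["cat cat dog".to_string()];

        let fitted = BagOfWordsEmbedding::with_config(config)
            .fit(&documents, &())
            .unwrap();
        let bow_matrix = fitted.transform(&documents).unwrap();

        // Both "cat" and "dog" are present, so every entry is exactly 1.0.
        assert_eq!(bow_matrix.shape(), &[1, 2]);
        for &value in bow_matrix.iter() {
            assert_abs_diff_eq!(value, 1.0, epsilon = 1e-10);
        }
    }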
}