scirs2_text/enhanced_vectorize.rs

//! Enhanced text vectorization with n-gram support
//!
//! This module provides enhanced vectorizers with n-gram support,
//! additional preprocessing options, and IDF smoothing.
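//!
//! # Examples
//!
//! A minimal usage sketch (illustrative only; it assumes this module is exposed
//! as `scirs2_text::enhanced_vectorize`):
//!
//! ```ignore
//! use scirs2_text::enhanced_vectorize::EnhancedTfidfVectorizer;
//!
//! let mut vectorizer = EnhancedTfidfVectorizer::new()
//!     .set_ngram_range((1, 2)).unwrap()
//!     .set_norm(Some("l2".to_string())).unwrap();
//!
//! // Learn the vocabulary and IDF weights, then build a document-term matrix.
//! let corpus = ["the quick brown fox", "the quick red fox"];
//! let matrix = vectorizer.fit_transform(&corpus).unwrap();
//! assert_eq!(matrix.nrows(), 2);
//! ```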

use crate::error::{Result, TextError};
use crate::tokenize::{NgramTokenizer, Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};

/// Enhanced count vectorizer with n-gram support
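///
/// # Examples
///
/// A minimal sketch of counting unigrams and bigrams (illustrative only; the
/// path `scirs2_text::enhanced_vectorize` is assumed):
///
/// ```ignore
/// use scirs2_text::enhanced_vectorize::EnhancedCountVectorizer;
///
/// let mut vectorizer = EnhancedCountVectorizer::new()
///     .set_ngram_range((1, 2)).unwrap()
///     .set_binary(false);
/// vectorizer.fit(&["hello world", "hello there"]).unwrap();
///
/// // One count per vocabulary entry; out-of-vocabulary tokens are ignored.
/// let counts = vectorizer.transform("hello world hello").unwrap();
/// ```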
pub struct EnhancedCountVectorizer {
    vocabulary: Vocabulary,
    binary: bool,
    ngram_range: (usize, usize),
    max_features: Option<usize>,
    min_df: f64,
    max_df: f64,
    lowercase: bool,
}

impl EnhancedCountVectorizer {
    /// Create a new enhanced count vectorizer
    pub fn new() -> Self {
        Self {
            vocabulary: Vocabulary::new(),
            binary: false,
            ngram_range: (1, 1),
            max_features: None,
            min_df: 0.0,
            max_df: 1.0,
            lowercase: true,
        }
    }

    /// Set whether to produce binary vectors
    pub fn set_binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    /// Set the n-gram range (min_n, max_n)
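    ///
    /// For example, `(1, 1)` produces unigrams only, while `(1, 2)` produces
    /// both unigrams and bigrams ("hello world" yields "hello", "world", and
    /// "hello world").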
    pub fn set_ngram_range(mut self, range: (usize, usize)) -> Result<Self> {
        if range.0 == 0 || range.1 < range.0 {
            return Err(TextError::InvalidInput(
                "Invalid n-gram range. Must have min_n > 0 and max_n >= min_n".to_string(),
            ));
        }
        self.ngram_range = range;
        Ok(self)
    }

    /// Set the maximum number of features
    pub fn set_max_features(mut self, maxfeatures: Option<usize>) -> Self {
        self.max_features = maxfeatures;
        self
    }

    /// Set the minimum document frequency
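    ///
    /// Expressed as a fraction of the corpus in `[0.0, 1.0]`; tokens that appear
    /// in fewer documents than this fraction are dropped during `fit`.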
    pub fn set_min_df(mut self, mindf: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&mindf) {
            return Err(TextError::InvalidInput(
                "min_df must be between 0.0 and 1.0".to_string(),
            ));
        }
        self.min_df = mindf;
        Ok(self)
    }

    /// Set the maximum document frequency
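    ///
    /// Expressed as a fraction of the corpus in `[0.0, 1.0]`; tokens that appear
    /// in more documents than this fraction are dropped during `fit`.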
    pub fn set_max_df(mut self, maxdf: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&maxdf) {
            return Err(TextError::InvalidInput(
                "max_df must be between 0.0 and 1.0".to_string(),
            ));
        }
        self.max_df = maxdf;
        Ok(self)
    }

    /// Set whether to lowercase text
    pub fn set_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    /// Get the vocabulary
    pub fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    /// Fit the vectorizer on a corpus
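    ///
    /// Learns the vocabulary from `texts`, applying the configured n-gram range,
    /// document-frequency filters, and `max_features` limit.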
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".to_string(),
            ));
        }

        // Clear existing vocabulary
        self.vocabulary = Vocabulary::new();

        // Track document frequencies
        let mut doc_frequencies: HashMap<String, usize> = HashMap::new();
        let total_docs = texts.len();

        // Process each document
        for text in texts {
            let mut seen_in_doc: HashSet<String> = HashSet::new();

            // Extract all n-grams in the range
            let all_tokens = self.extract_ngrams(text)?;

            // Count document frequencies (each token counts once per document)
            for token in all_tokens {
                if seen_in_doc.insert(token.clone()) {
                    *doc_frequencies.entry(token.clone()).or_insert(0) += 1;
                }

                // Add to a provisional vocabulary (rebuilt after filtering below)
                self.vocabulary.add_token(&token);
            }
        }

        // Filter by document frequency
        let min_count = (self.min_df * total_docs as f64).ceil() as usize;
        let max_count = (self.max_df * total_docs as f64).floor() as usize;

        let mut filtered_tokens: Vec<(String, usize)> = doc_frequencies
            .into_iter()
            .filter(|(_, count)| *count >= min_count && *count <= max_count)
            .collect();

        // Sort by document frequency (descending, ties broken alphabetically
        // for determinism) and limit features if needed
        filtered_tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        if let Some(max_features) = self.max_features {
            filtered_tokens.truncate(max_features);
        }

        // Rebuild vocabulary with filtered tokens
        self.vocabulary = Vocabulary::with_maxsize(self.max_features.unwrap_or(usize::MAX));
        for (token, _) in filtered_tokens {
            self.vocabulary.add_token(&token);
        }

        Ok(())
    }

    /// Extract n-grams from text based on the configured range
    fn extract_ngrams(&self, text: &str) -> Result<Vec<String>> {
        let text = if self.lowercase {
            text.to_lowercase()
        } else {
            text.to_string()
        };

        // If range is (1, 1), just use word tokenizer
        let all_ngrams = if self.ngram_range == (1, 1) {
            let tokenizer = WordTokenizer::new(false);
            tokenizer.tokenize(&text)?
        } else {
            // Use n-gram tokenizer for the range
            let ngram_tokenizer =
                NgramTokenizer::with_range(self.ngram_range.0, self.ngram_range.1)?;
            ngram_tokenizer.tokenize(&text)?
        };

        Ok(all_ngrams)
    }

    /// Transform text into a count vector
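    ///
    /// Tokens that are not in the fitted vocabulary are ignored.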
    pub fn transform(&self, text: &str) -> Result<Array1<f64>> {
        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "Vocabulary is empty. Call fit() first".to_string(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let mut vector = Array1::zeros(vocab_size);

        // Extract n-grams
        let tokens = self.extract_ngrams(text)?;

        // Count tokens
        for token in tokens {
            if let Some(idx) = self.vocabulary.get_index(&token) {
                vector[idx] += 1.0;
            }
        }

        // Make binary if requested
        if self.binary {
            for val in vector.iter_mut() {
                if *val > 0.0 {
                    *val = 1.0;
                }
            }
        }

        Ok(vector)
    }

    /// Transform multiple texts into a count matrix
    pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "Vocabulary is empty. Call fit() first".to_string(),
            ));
        }

        let n_samples = texts.len();
        let vocab_size = self.vocabulary.len();
        let mut matrix = Array2::zeros((n_samples, vocab_size));

        for (i, text) in texts.iter().enumerate() {
            let vector = self.transform(text)?;
            matrix.row_mut(i).assign(&vector);
        }

        Ok(matrix)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }
}

impl Default for EnhancedCountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Enhanced TF-IDF vectorizer with IDF smoothing options
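///
/// # Examples
///
/// A minimal sketch of fitting and transforming (illustrative only; the path
/// `scirs2_text::enhanced_vectorize` is assumed):
///
/// ```ignore
/// use scirs2_text::enhanced_vectorize::EnhancedTfidfVectorizer;
///
/// let mut vectorizer = EnhancedTfidfVectorizer::new()
///     .set_sublinear_tf(true)
///     .set_norm(Some("l2".to_string())).unwrap();
///
/// // Each row of the result is an L2-normalized TF-IDF vector.
/// let tfidf = vectorizer.fit_transform(&["one fish", "two fish"]).unwrap();
/// assert_eq!(tfidf.nrows(), 2);
/// ```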
pub struct EnhancedTfidfVectorizer {
    count_vectorizer: EnhancedCountVectorizer,
    useidf: bool,
    smoothidf: bool,
    sublinear_tf: bool,
    norm: Option<String>,
    idf_: Option<Array1<f64>>,
}

impl EnhancedTfidfVectorizer {
    /// Create a new enhanced TF-IDF vectorizer
    pub fn new() -> Self {
        Self {
            count_vectorizer: EnhancedCountVectorizer::new(),
            useidf: true,
            smoothidf: true,
            sublinear_tf: false,
            norm: Some("l2".to_string()),
            idf_: None,
        }
    }

    /// Set whether to use IDF weighting
    pub fn set_use_idf(mut self, useidf: bool) -> Self {
        self.useidf = useidf;
        self
    }

    /// Set whether to smooth IDF weights
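    ///
    /// Smoothing adds one to both the document count and each document
    /// frequency, as if an extra document contained every term once, which
    /// avoids division by zero for terms with zero document frequency.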
    pub fn set_smooth_idf(mut self, smoothidf: bool) -> Self {
        self.smoothidf = smoothidf;
        self
    }

    /// Set whether to use sublinear TF scaling
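    ///
    /// With sublinear scaling, each non-zero term count `tf` is replaced by
    /// `1 + ln(tf)`.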
    pub fn set_sublinear_tf(mut self, sublineartf: bool) -> Self {
        self.sublinear_tf = sublineartf;
        self
    }

    /// Set the normalization method (None, "l1", or "l2")
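    ///
    /// With `"l1"` the output vector is divided by the sum of its absolute
    /// values; with `"l2"` it is divided by its Euclidean norm; `None` leaves
    /// it unscaled.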
    pub fn set_norm(mut self, norm: Option<String>) -> Result<Self> {
        if let Some(ref n) = norm {
            if n != "l1" && n != "l2" {
                return Err(TextError::InvalidInput(
                    "Norm must be 'l1', 'l2', or None".to_string(),
                ));
            }
        }
        self.norm = norm;
        Ok(self)
    }

    /// Set n-gram range
    pub fn set_ngram_range(mut self, range: (usize, usize)) -> Result<Self> {
        self.count_vectorizer = self.count_vectorizer.set_ngram_range(range)?;
        Ok(self)
    }

    /// Set maximum features
    pub fn set_max_features(mut self, maxfeatures: Option<usize>) -> Self {
        self.count_vectorizer = self.count_vectorizer.set_max_features(maxfeatures);
        self
    }

    /// Get the vocabulary
    pub fn vocabulary(&self) -> &Vocabulary {
        self.count_vectorizer.vocabulary()
    }

    /// Fit the vectorizer on a corpus
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        // Fit the count vectorizer
        self.count_vectorizer.fit(texts)?;

        if self.useidf {
            // Calculate IDF weights
            self.calculate_idf(texts)?;
        }

        Ok(())
    }

    /// Calculate IDF weights
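    ///
    /// With smoothing, `idf(t) = ln((1 + n) / (1 + df(t))) + 1`; without it,
    /// `idf(t) = ln(n / max(df(t), 1)) + 1`, where `n` is the number of
    /// documents and `df(t)` is the number of documents containing term `t`.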
    fn calculate_idf(&mut self, texts: &[&str]) -> Result<()> {
        let vocab_size = self.count_vectorizer.vocabulary().len();
        let mut df: Array1<f64> = Array1::zeros(vocab_size);
        let n_samples = texts.len() as f64;

        // Count document frequencies
        for text in texts {
            let count_vec = self.count_vectorizer.transform(text)?;
            for (idx, &count) in count_vec.iter().enumerate() {
                if count > 0.0 {
                    df[idx] += 1.0;
                }
            }
        }

        // Calculate IDF
        let mut idf = Array1::zeros(vocab_size);
        for (idx, &doc_freq) in df.iter().enumerate() {
            if self.smoothidf {
                idf[idx] = (1.0 + n_samples) / (1.0 + doc_freq);
            } else {
                idf[idx] = n_samples / doc_freq.max(1.0);
            }
            idf[idx] = idf[idx].ln() + 1.0;
        }

        self.idf_ = Some(idf);
        Ok(())
    }

    /// Transform text into a TF-IDF vector
    pub fn transform(&self, text: &str) -> Result<Array1<f64>> {
        // Get count vector
        let mut vector = self.count_vectorizer.transform(text)?;

        // Apply sublinear TF scaling if requested
        if self.sublinear_tf {
            for val in vector.iter_mut() {
                if *val > 0.0 {
                    *val = 1.0 + (*val).ln();
                }
            }
        }

        // Apply IDF weighting
        if self.useidf {
            if let Some(ref idf) = self.idf_ {
                vector *= idf;
            } else {
                return Err(TextError::VocabularyError(
                    "IDF weights not calculated. Call fit() first".to_string(),
                ));
            }
        }

        // Apply normalization
        if let Some(ref norm) = self.norm {
            match norm.as_str() {
                "l1" => {
                    let norm_val = vector.iter().map(|x| x.abs()).sum::<f64>();
                    if norm_val > 0.0 {
                        vector /= norm_val;
                    }
                }
                "l2" => {
                    let norm_val = vector.dot(&vector).sqrt();
                    if norm_val > 0.0 {
                        vector /= norm_val;
                    }
                }
                _ => {}
            }
        }

        Ok(vector)
    }

    /// Transform multiple texts into a TF-IDF matrix
    pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        let n_samples = texts.len();
        let vocab_size = self.count_vectorizer.vocabulary().len();
        let mut matrix = Array2::zeros((n_samples, vocab_size));

        for (i, text) in texts.iter().enumerate() {
            let vector = self.transform(text)?;
            matrix.row_mut(i).assign(&vector);
        }

        Ok(matrix)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }
}

impl Default for EnhancedTfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enhanced_count_vectorizer_unigrams() {
        let mut vectorizer = EnhancedCountVectorizer::new();

        let documents = vec![
            "this is a test",
            "this is another test",
            "something different here",
        ];

        vectorizer.fit(&documents).unwrap();

        let vector = vectorizer.transform("this is a test").unwrap();
        assert!(!vector.is_empty());
    }

    #[test]
    fn test_enhanced_count_vectorizer_ngrams() {
        let mut vectorizer = EnhancedCountVectorizer::new()
            .set_ngram_range((1, 2))
            .unwrap();

        let documents = vec!["hello world", "hello there", "world peace"];

        vectorizer.fit(&documents).unwrap();

        // Should include both unigrams and bigrams
        let vocab = vectorizer.vocabulary();
        assert!(vocab.len() > 4); // More than the four distinct unigrams alone
    }

    #[test]
    fn test_enhanced_tfidf_vectorizer() {
        let mut vectorizer = EnhancedTfidfVectorizer::new()
            .set_smooth_idf(true)
            .set_norm(Some("l2".to_string()))
            .unwrap();

        let documents = vec![
            "this is a test",
            "this is another test",
            "something different here",
        ];

        vectorizer.fit(&documents).unwrap();

        let vector = vectorizer.transform("this is a test").unwrap();

        // Check L2 normalization
        let norm = vector.dot(&vector).sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_features() {
        let mut vectorizer = EnhancedCountVectorizer::new().set_max_features(Some(5));

        let documents = vec![
            "one two three four five six seven eight",
            "one two three four five six seven eight nine ten",
        ];

        vectorizer.fit(&documents).unwrap();

        // Should only keep top 5 features
        assert_eq!(vectorizer.vocabulary().len(), 5);
    }

    #[test]
    fn test_document_frequency_filtering() {
        let mut vectorizer = EnhancedCountVectorizer::new().set_min_df(0.5).unwrap(); // Token must appear in at least 50% of docs

        let documents = vec![
            "common word rare",
            "common word unique",
            "common another distinct",
        ];

        vectorizer.fit(&documents).unwrap();

        // "common" (3/3 docs) and "word" (2/3 docs) meet the 50% threshold;
        // singletons like "rare" and "unique" are filtered out
        let vocab = vectorizer.vocabulary();
        assert!(vocab.contains("common"));
        assert!(!vocab.contains("rare"));
        assert!(!vocab.contains("unique"));
    }
}