scirs2_transform/text.rs

//! Text processing transformers for feature extraction
//!
//! This module provides utilities for converting text data into numerical features
//! suitable for machine learning algorithms.

use ahash::AHasher;
use regex::Regex;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use crate::error::{Result, TransformError};

/// Count vectorizer for converting text documents to term frequency vectors
pub struct CountVectorizer {
    /// Vocabulary mapping from terms to indices
    vocabulary: HashMap<String, usize>,
    /// Inverse vocabulary mapping from indices to terms
    feature_names: Vec<String>,
    /// Maximum number of features
    max_features: Option<usize>,
    /// Minimum document frequency
    min_df: f64,
    /// Maximum document frequency
    max_df: f64,
    /// Whether to convert to lowercase
    lowercase: bool,
    /// Token pattern regex
    token_pattern: Regex,
    /// Set of stop words to exclude
    stop_words: HashSet<String>,
    /// Whether the vectorizer has been fitted
    fitted: bool,
}

impl CountVectorizer {
    /// Create a new count vectorizer
    pub fn new() -> Self {
        CountVectorizer {
            vocabulary: HashMap::new(),
            feature_names: Vec::new(),
            max_features: None,
            min_df: 1.0,
            max_df: 1.0,
            lowercase: true,
            token_pattern: Regex::new(r"\b\w+\b").unwrap(),
            stop_words: HashSet::new(),
            fitted: false,
        }
    }

    /// Set maximum number of features
    #[allow(dead_code)]
    pub fn with_max_features(mut self, max_features: usize) -> Self {
        self.max_features = Some(max_features);
        self
    }

    /// Set minimum document frequency
    #[allow(dead_code)]
    pub fn with_min_df(mut self, min_df: f64) -> Self {
        self.min_df = min_df;
        self
    }

    /// Set maximum document frequency
    #[allow(dead_code)]
    pub fn with_max_df(mut self, max_df: f64) -> Self {
        self.max_df = max_df;
        self
    }

    /// Set whether to convert to lowercase
    #[allow(dead_code)]
    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    /// Set custom token pattern
    #[allow(dead_code)]
    pub fn with_token_pattern(mut self, pattern: &str) -> Result<Self> {
        self.token_pattern = Regex::new(pattern)
            .map_err(|e| TransformError::InvalidInput(format!("Invalid regex pattern: {e}")))?;
        Ok(self)
    }

    /// Set stop words
    #[allow(dead_code)]
    pub fn with_stop_words(mut self, stop_words: Vec<String>) -> Self {
        self.stop_words = stop_words.into_iter().collect();
        self
    }

    /// Tokenize a document
    fn tokenize(&self, doc: &str) -> Vec<String> {
        let text = if self.lowercase {
            doc.to_lowercase()
        } else {
            doc.to_string()
        };

        self.token_pattern
            .find_iter(&text)
            .map(|m| m.as_str().to_string())
            .filter(|token| !self.stop_words.contains(token))
            .collect()
    }

    /// Fit the vectorizer on a collection of documents
    pub fn fit(&mut self, documents: &[String]) -> Result<()> {
        if documents.is_empty() {
            return Err(TransformError::InvalidInput(
                "Empty document collection".into(),
            ));
        }

        // Count document frequencies (number of documents containing each term)
        let mut term_doc_freq: HashMap<String, usize> = HashMap::new();
        let n_docs = documents.len();

        for doc in documents {
            let tokens: HashSet<String> = self.tokenize(doc).into_iter().collect();
            for token in tokens {
                *term_doc_freq.entry(token).or_insert(0) += 1;
            }
        }

        // Filter by document frequency: min_df below 1.0 is treated as a fraction of
        // the corpus, 1.0 or more as an absolute document count (the default of 1.0
        // keeps every term seen at least once).
        let min_df_count = if self.min_df < 1.0 {
            (self.min_df * n_docs as f64).ceil() as usize
        } else {
            self.min_df as usize
        };

        let max_df_count = if self.max_df <= 1.0 {
            (self.max_df * n_docs as f64).floor() as usize
        } else {
            self.max_df as usize
        };

        let mut filtered_terms: Vec<(String, usize)> = term_doc_freq
            .into_iter()
            .filter(|(_, freq)| *freq >= min_df_count && *freq <= max_df_count)
            .collect();

        // Sort by document frequency (descending) for max_features selection
        filtered_terms.sort_by(|a, b| b.1.cmp(&a.1));

        // Limit to max_features if specified
        if let Some(max_feat) = self.max_features {
            filtered_terms.truncate(max_feat);
        }

        // Build vocabulary
        self.vocabulary.clear();
        self.feature_names.clear();

        for (idx, (term, _freq)) in filtered_terms.into_iter().enumerate() {
            self.vocabulary.insert(term.clone(), idx);
            self.feature_names.push(term);
        }

        self.fitted = true;
        Ok(())
    }

    /// Transform documents to count vectors
    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(TransformError::NotFitted(
                "CountVectorizer must be fitted before transform".into(),
            ));
        }

        let n_samples = documents.len();
        let n_features = self.vocabulary.len();
        let mut result = Array2::zeros((n_samples, n_features));

        for (i, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);
            for token in tokens {
                if let Some(&idx) = self.vocabulary.get(&token) {
                    result[[i, idx]] += 1.0;
                }
            }
        }

        Ok(result)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, documents: &[String]) -> Result<Array2<f64>> {
        self.fit(documents)?;
        self.transform(documents)
    }

    /// Get feature names
    #[allow(dead_code)]
    pub fn get_feature_names(&self) -> &[String] {
        &self.feature_names
    }
}

impl Default for CountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}
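
// A minimal usage sketch for `CountVectorizer`, added as an illustrative test; the
// documents and assertions below are assumptions chosen for the example, not part of
// the original API surface.
#[cfg(test)]
mod count_vectorizer_example {
    use super::*;

    #[test]
    fn counts_terms_per_document() {
        let docs = vec!["the cat sat".to_string(), "the dog sat".to_string()];

        let mut vectorizer = CountVectorizer::new();
        let counts = vectorizer.fit_transform(&docs).unwrap();

        // One row per document, one column per vocabulary term.
        assert_eq!(counts.shape()[0], 2);
        assert_eq!(counts.shape()[1], vectorizer.get_feature_names().len());

        // Each document contains exactly three tokens, so each row sums to 3.
        for i in 0..counts.shape()[0] {
            assert_eq!(counts.row(i).sum(), 3.0);
        }
    }
}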

/// TF-IDF vectorizer for converting text to TF-IDF features
pub struct TfidfVectorizer {
    /// Underlying count vectorizer
    count_vectorizer: CountVectorizer,
    /// IDF values for each feature
    idf: Array1<f64>,
    /// Whether to use IDF weighting
    use_idf: bool,
    /// Whether to apply L2 normalization
    norm: bool,
    /// Whether to smooth IDF weights by adding 1 to document frequencies
    smooth_idf: bool,
    /// Whether to apply sublinear term frequency scaling (1 + ln(tf))
    sublinear_tf: bool,
}

impl TfidfVectorizer {
    /// Create a new TF-IDF vectorizer
    pub fn new() -> Self {
        TfidfVectorizer {
            count_vectorizer: CountVectorizer::new(),
            idf: Array1::zeros(0),
            use_idf: true,
            norm: true,
            smooth_idf: true,
            sublinear_tf: false,
        }
    }

    /// Set whether to use IDF weighting
    #[allow(dead_code)]
    pub fn with_use_idf(mut self, use_idf: bool) -> Self {
        self.use_idf = use_idf;
        self
    }

    /// Set whether to apply L2 normalization
    #[allow(dead_code)]
    pub fn with_norm(mut self, norm: bool) -> Self {
        self.norm = norm;
        self
    }

    /// Set whether to smooth IDF weights
    #[allow(dead_code)]
    pub fn with_smooth_idf(mut self, smooth_idf: bool) -> Self {
        self.smooth_idf = smooth_idf;
        self
    }

    /// Set whether to use sublinear term frequency
    #[allow(dead_code)]
    pub fn with_sublinear_tf(mut self, sublinear_tf: bool) -> Self {
        self.sublinear_tf = sublinear_tf;
        self
    }

    /// Configure the underlying count vectorizer
    #[allow(dead_code)]
    pub fn configure_count_vectorizer<F>(mut self, f: F) -> Self
    where
        F: FnOnce(CountVectorizer) -> CountVectorizer,
    {
        self.count_vectorizer = f(self.count_vectorizer);
        self
    }

    /// Fit the vectorizer on documents
    pub fn fit(&mut self, documents: &[String]) -> Result<()> {
        // Fit count vectorizer
        self.count_vectorizer.fit(documents)?;

        if self.use_idf {
            // Calculate IDF values
            let n_samples = documents.len() as f64;
            let n_features = self.count_vectorizer.vocabulary.len();
            let mut df = Array1::zeros(n_features);

            // Count document frequencies
            for doc in documents {
                let tokens: HashSet<String> =
                    self.count_vectorizer.tokenize(doc).into_iter().collect();
                for token in tokens {
                    if let Some(&idx) = self.count_vectorizer.vocabulary.get(&token) {
                        df[idx] += 1.0;
                    }
                }
            }

            // Calculate IDF
            if self.smooth_idf {
                self.idf = df.mapv(|d: f64| ((n_samples + 1.0) / (d + 1.0)).ln() + 1.0);
            } else {
                self.idf = df.mapv(|d: f64| (n_samples / d).ln() + 1.0);
            }
        }

        Ok(())
    }

    /// Transform documents to TF-IDF vectors
    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        // Get count vectors
        let mut x = self.count_vectorizer.transform(documents)?;

        // Apply sublinear TF scaling
        if self.sublinear_tf {
            x.mapv_inplace(|v| if v > 0.0 { 1.0 + v.ln() } else { 0.0 });
        }

        // Apply IDF weighting
        if self.use_idf {
            for i in 0..x.shape()[0] {
                for j in 0..x.shape()[1] {
                    x[[i, j]] *= self.idf[j];
                }
            }
        }

        // Apply L2 normalization
        if self.norm {
            for i in 0..x.shape()[0] {
                let row = x.row(i);
                let norm = row.dot(&row).sqrt();
                if norm > 0.0 {
                    x.row_mut(i).mapv_inplace(|v| v / norm);
                }
            }
        }

        Ok(x)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, documents: &[String]) -> Result<Array2<f64>> {
        self.fit(documents)?;
        self.transform(documents)
    }

    /// Get feature names
    #[allow(dead_code)]
    pub fn get_feature_names(&self) -> &[String] {
        self.count_vectorizer.get_feature_names()
    }
}

impl Default for TfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}
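
// A minimal usage sketch for `TfidfVectorizer`, added as an illustrative test. With
// the defaults used here, weights follow the smooth IDF formula
// idf(t) = ln((n + 1) / (df(t) + 1)) + 1 and each row is L2-normalized; the documents
// are arbitrary examples.
#[cfg(test)]
mod tfidf_vectorizer_example {
    use super::*;

    #[test]
    fn rows_have_unit_l2_norm() {
        let docs = vec![
            "rust makes systems programming safer".to_string(),
            "rust and python are programming languages".to_string(),
        ];

        // `configure_count_vectorizer` forwards builder calls to the inner CountVectorizer.
        let mut vectorizer =
            TfidfVectorizer::new().configure_count_vectorizer(|cv| cv.with_lowercase(true));
        let x = vectorizer.fit_transform(&docs).unwrap();

        // With `norm` enabled (the default), every non-empty row has unit L2 norm.
        for i in 0..x.shape()[0] {
            let row = x.row(i);
            let norm = row.dot(&row).sqrt();
            assert!((norm - 1.0).abs() < 1e-9);
        }
    }
}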

/// Hashing vectorizer for memory-efficient text vectorization
pub struct HashingVectorizer {
    /// Number of features (hash space size)
    n_features: usize,
    /// Whether to convert to lowercase
    lowercase: bool,
    /// Token pattern regex
    token_pattern: Regex,
    /// Whether to use binary occurrence instead of counts
    binary: bool,
    /// Norm to use for normalization
    norm: Option<String>,
}

impl HashingVectorizer {
    /// Create a new hashing vectorizer
    pub fn new(n_features: usize) -> Self {
        HashingVectorizer {
            n_features,
            lowercase: true,
            token_pattern: Regex::new(r"\b\w+\b").unwrap(),
            binary: false,
            norm: Some("l2".to_string()),
        }
    }

    /// Set whether to use binary occurrence
    #[allow(dead_code)]
    pub fn with_binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    /// Set normalization method
    #[allow(dead_code)]
    pub fn with_norm(mut self, norm: Option<String>) -> Self {
        self.norm = norm;
        self
    }

    /// Set whether to convert to lowercase
    #[allow(dead_code)]
    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    /// Hash a token to a feature index
    fn hash_token(&self, token: &str) -> usize {
        let mut hasher = AHasher::default();
        token.hash(&mut hasher);
        (hasher.finish() as usize) % self.n_features
    }

    /// Tokenize a document
    fn tokenize(&self, doc: &str) -> Vec<String> {
        let text = if self.lowercase {
            doc.to_lowercase()
        } else {
            doc.to_string()
        };

        self.token_pattern
            .find_iter(&text)
            .map(|m| m.as_str().to_string())
            .collect()
    }

    /// Transform documents to hashed feature vectors
    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        let n_samples = documents.len();
        let mut result = Array2::zeros((n_samples, self.n_features));

        for (i, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);

            if self.binary {
                let unique_indices: HashSet<usize> =
                    tokens.iter().map(|token| self.hash_token(token)).collect();

                for idx in unique_indices {
                    result[[i, idx]] = 1.0;
                }
            } else {
                for token in tokens {
                    let idx = self.hash_token(&token);
                    result[[i, idx]] += 1.0;
                }
            }

            // Apply normalization
            if let Some(ref norm_type) = self.norm {
                let row = result.row(i).to_owned();
                let norm_value = match norm_type.as_str() {
                    "l1" => row.iter().map(|v: &f64| v.abs()).sum::<f64>(),
                    "l2" => row.dot(&row).sqrt(),
                    _ => continue,
                };

                if norm_value > 0.0 {
                    result.row_mut(i).mapv_inplace(|v| v / norm_value);
                }
            }
        }

        Ok(result)
    }
}
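
// A minimal usage sketch for `HashingVectorizer`, added as an illustrative test. No
// fitting step is needed: tokens are hashed modulo `n_features`, so the output width
// is fixed up front. The choice of 16 features and the documents are arbitrary.
#[cfg(test)]
mod hashing_vectorizer_example {
    use super::*;

    #[test]
    fn output_width_is_fixed_by_n_features() {
        let docs = vec![
            "hash the tokens".to_string(),
            "no vocabulary is stored".to_string(),
        ];

        // Disable normalization so the rows hold raw counts.
        let vectorizer = HashingVectorizer::new(16).with_norm(None);
        let x = vectorizer.transform(&docs).unwrap();

        assert_eq!(x.shape(), &[2, 16]);
        // Hash collisions only merge counts, so each row still sums to its token count.
        assert_eq!(x.row(0).sum(), 3.0);
        assert_eq!(x.row(1).sum(), 4.0);
    }
}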

/// Streaming count vectorizer that can learn vocabulary incrementally
pub struct StreamingCountVectorizer {
    /// Current vocabulary
    vocabulary: HashMap<String, usize>,
    /// Document frequency counts
    doc_freq: HashMap<String, usize>,
    /// Number of documents seen
    n_docs_seen: usize,
    /// Maximum vocabulary size
    max_features: Option<usize>,
    /// Whether to convert to lowercase
    lowercase: bool,
    /// Token pattern regex
    token_pattern: Regex,
}

impl StreamingCountVectorizer {
    /// Create a new streaming count vectorizer
    pub fn new() -> Self {
        StreamingCountVectorizer {
            vocabulary: HashMap::new(),
            doc_freq: HashMap::new(),
            n_docs_seen: 0,
            max_features: None,
            lowercase: true,
            token_pattern: Regex::new(r"\b\w+\b").unwrap(),
        }
    }

    /// Set maximum vocabulary size
    #[allow(dead_code)]
    pub fn with_max_features(mut self, max_features: usize) -> Self {
        self.max_features = Some(max_features);
        self
    }

    /// Tokenize a document
    fn tokenize(&self, doc: &str) -> Vec<String> {
        let text = if self.lowercase {
            doc.to_lowercase()
        } else {
            doc.to_string()
        };

        self.token_pattern
            .find_iter(&text)
            .map(|m| m.as_str().to_string())
            .collect()
    }

    /// Update vocabulary with new documents
    pub fn partial_fit(&mut self, documents: &[String]) -> Result<()> {
        for doc in documents {
            self.n_docs_seen += 1;
            let tokens: HashSet<String> = self.tokenize(doc).into_iter().collect();

            for token in tokens {
                *self.doc_freq.entry(token.clone()).or_insert(0) += 1;

                if !self.vocabulary.contains_key(&token) {
                    if let Some(max_feat) = self.max_features {
                        if self.vocabulary.len() >= max_feat {
                            // Vocabulary is full: replace the least frequent term if the
                            // new token has a strictly higher document frequency.
                            if let Some((min_token, _)) = self
                                .vocabulary
                                .iter()
                                .min_by_key(|(t, _)| self.doc_freq.get(*t).unwrap_or(&0))
                            {
                                let min_token = min_token.clone();
                                let min_freq = *self.doc_freq.get(&min_token).unwrap_or(&0);
                                let new_freq = *self.doc_freq.get(&token).unwrap_or(&0);

                                if new_freq > min_freq {
                                    let old_idx = self.vocabulary.remove(&min_token).unwrap();
                                    self.vocabulary.insert(token, old_idx);
                                }
                            }
                        } else {
                            self.vocabulary.insert(token, self.vocabulary.len());
                        }
                    } else {
                        self.vocabulary.insert(token, self.vocabulary.len());
                    }
                }
            }
        }

        Ok(())
    }

    /// Transform documents using current vocabulary
    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        let n_samples = documents.len();
        let n_features = self.vocabulary.len();

        if n_features == 0 {
            return Err(TransformError::NotFitted(
                "No vocabulary learned yet".into(),
            ));
        }

        let mut result = Array2::zeros((n_samples, n_features));

        for (i, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);
            for token in tokens {
                if let Some(&idx) = self.vocabulary.get(&token) {
                    result[[i, idx]] += 1.0;
                }
            }
        }

        Ok(result)
    }
}

impl Default for StreamingCountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}
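
// A minimal usage sketch for `StreamingCountVectorizer`, added as an illustrative
// test: the vocabulary grows across `partial_fit` calls, so later batches can widen
// the feature space. The two batches below are arbitrary examples.
#[cfg(test)]
mod streaming_count_vectorizer_example {
    use super::*;

    #[test]
    fn vocabulary_grows_across_batches() {
        let mut vectorizer = StreamingCountVectorizer::new();
        let probe = vec!["first batch".to_string()];

        vectorizer.partial_fit(&probe).unwrap();
        let width_after_first = vectorizer.transform(&probe).unwrap().shape()[1];

        vectorizer
            .partial_fit(&["second batch with entirely new words".to_string()])
            .unwrap();
        let width_after_second = vectorizer.transform(&probe).unwrap().shape()[1];

        // Terms introduced by the second batch enlarge the feature space.
        assert!(width_after_second > width_after_first);
    }
}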