scirs2_text/
vectorize.rs

1//! Text vectorization utilities
2//!
3//! This module provides functionality for converting text into
4//! numerical vector representations.
5
6use crate::error::{Result, TextError};
7use crate::tokenize::{Tokenizer, WordTokenizer};
8use crate::vocabulary::Vocabulary;
9use scirs2_core::ndarray::{Array1, Array2, Axis};
10use scirs2_core::parallel_ops;
11use std::collections::HashMap;
12
13/// Trait for text vectorizers
14pub trait Vectorizer: Clone {
15    /// Fit the vectorizer on a corpus of texts
16    fn fit(&mut self, texts: &[&str]) -> Result<()>;
17
18    /// Transform a text into a vector
19    fn transform(&self, text: &str) -> Result<Array1<f64>>;
20
21    /// Transform a batch of texts into a matrix where each row is a document vector
22    fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>>;
23
24    /// Fit on a corpus and then transform a batch of texts
25    fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
26        self.fit(texts)?;
27        self.transform_batch(texts)
28    }
29}
30
31/// Count vectorizer that uses a bag-of-words representation
32pub struct CountVectorizer {
33    tokenizer: Box<dyn Tokenizer + Send + Sync>,
34    vocabulary: Vocabulary,
35    binary: bool, // If true, all non-zero counts are set to 1
36}
37
38impl Clone for CountVectorizer {
39    fn clone(&self) -> Self {
40        Self {
41            tokenizer: self.tokenizer.clone_box(),
42            vocabulary: self.vocabulary.clone(),
43            binary: self.binary,
44        }
45    }
46}
47
48impl CountVectorizer {
49    /// Create a new count vectorizer
50    pub fn new(binary: bool) -> Self {
51        Self {
52            tokenizer: Box::new(WordTokenizer::default()),
53            vocabulary: Vocabulary::new(),
54            binary,
55        }
56    }
57
58    /// Create a count vectorizer with a custom tokenizer
59    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>, binary: bool) -> Self {
60        Self {
61            tokenizer,
62            vocabulary: Vocabulary::new(),
63            binary,
64        }
65    }
66
67    /// Get a reference to the vocabulary
68    pub fn vocabulary(&self) -> &Vocabulary {
69        &self.vocabulary
70    }
71
72    /// Get the vocabulary size
73    pub fn vocabulary_size(&self) -> usize {
74        self.vocabulary.len()
75    }
76
77    /// Get feature count for a specific document and feature index from a matrix
78    pub fn get_feature_count(
79        &self,
80        matrix: &Array2<f64>,
81        document_index: usize,
82        feature_index: usize,
83    ) -> Option<f64> {
84        if document_index < matrix.nrows() && feature_index < matrix.ncols() {
85            Some(matrix[[document_index, feature_index]])
86        } else {
87            None
88        }
89    }
90
91    /// Get vocabulary as HashMap for compatibility with visualization
92    pub fn vocabulary_map(&self) -> HashMap<String, usize> {
93        self.vocabulary.token_to_index().clone()
94    }
95}
96
97impl Default for CountVectorizer {
98    fn default() -> Self {
99        Self::new(false)
100    }
101}
102
103impl Vectorizer for CountVectorizer {
104    fn fit(&mut self, texts: &[&str]) -> Result<()> {
105        if texts.is_empty() {
106            return Err(TextError::InvalidInput(
107                "No texts provided for fitting".into(),
108            ));
109        }
110
111        // Clear any existing vocabulary
112        self.vocabulary = Vocabulary::new();
113
114        // Process all documents to build vocabulary
115        for &text in texts {
116            let tokens = self.tokenizer.tokenize(text)?;
117            for token in tokens {
118                self.vocabulary.add_token(&token);
119            }
120        }
121
122        Ok(())
123    }
124
125    fn transform(&self, text: &str) -> Result<Array1<f64>> {
126        if self.vocabulary.is_empty() {
127            return Err(TextError::VocabularyError(
128                "Vocabulary is empty. Call fit() first".into(),
129            ));
130        }
131
132        let vocab_size = self.vocabulary.len();
133        let mut vector = Array1::zeros(vocab_size);
134
135        // Tokenize the text
136        let tokens = self.tokenizer.tokenize(text)?;
137
138        // Count tokens
139        for token in tokens {
140            if let Some(idx) = self.vocabulary.get_index(&token) {
141                vector[idx] += 1.0;
142            }
143        }
144
145        // Make binary if requested
146        if self.binary {
147            for val in vector.iter_mut() {
148                if *val > 0.0 {
149                    *val = 1.0;
150                }
151            }
152        }
153
154        Ok(vector)
155    }
156
157    fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
158        if self.vocabulary.is_empty() {
159            return Err(TextError::VocabularyError(
160                "Vocabulary is empty. Call fit() first".into(),
161            ));
162        }
163
164        if texts.is_empty() {
165            return Ok(Array2::zeros((0, self.vocabulary.len())));
166        }
167
168        // Use scirs2-core::parallel for parallel processing
169        // Clone data to avoid lifetime issues
170        let texts_owned: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
171        let self_clone = self.clone();
172
173        let vectors = parallel_ops::parallel_map_result(&texts_owned, move |text| {
174            self_clone.transform(text).map_err(|e| {
175                // Convert TextError to CoreError
176                scirs2_core::CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
177                    format!("Text vectorization error: {e}"),
178                ))
179            })
180        })?;
181
182        // Convert to 2D array
183        let n_samples = vectors.len();
184        let n_features = self.vocabulary.len();
185
186        let mut matrix = Array2::zeros((n_samples, n_features));
187        for (i, vec) in vectors.iter().enumerate() {
188            matrix.row_mut(i).assign(vec);
189        }
190
191        Ok(matrix)
192    }
193}
194
195/// TF-IDF vectorizer that computes term frequency-inverse document frequency
196#[derive(Clone)]
197pub struct TfidfVectorizer {
198    count_vectorizer: CountVectorizer,
199    idf: Option<Array1<f64>>,
200    smoothidf: bool,
201    norm: Option<String>, // None, "l1", "l2"
202}
203
204impl TfidfVectorizer {
205    /// Create a new TF-IDF vectorizer
206    pub fn new(binary: bool, smoothidf: bool, norm: Option<String>) -> Self {
207        Self {
208            count_vectorizer: CountVectorizer::new(binary),
209            idf: None,
210            smoothidf,
211            norm,
212        }
213    }
214
215    /// Create a TF-IDF vectorizer with a custom tokenizer
216    pub fn with_tokenizer(
217        tokenizer: Box<dyn Tokenizer + Send + Sync>,
218        binary: bool,
219        smoothidf: bool,
220        norm: Option<String>,
221    ) -> Self {
222        Self {
223            count_vectorizer: CountVectorizer::with_tokenizer(tokenizer, binary),
224            idf: None,
225            smoothidf,
226            norm,
227        }
228    }
229
230    /// Get a reference to the vocabulary
231    pub fn vocabulary(&self) -> &Vocabulary {
232        self.count_vectorizer.vocabulary()
233    }
234
235    /// Get the vocabulary size
236    pub fn vocabulary_size(&self) -> usize {
237        self.count_vectorizer.vocabulary_size()
238    }
239
240    /// Get TF-IDF score for a specific document and feature index from a matrix
241    pub fn get_feature_score(
242        &self,
243        matrix: &Array2<f64>,
244        document_index: usize,
245        feature_index: usize,
246    ) -> Option<f64> {
247        if document_index < matrix.nrows() && feature_index < matrix.ncols() {
248            Some(matrix[[document_index, feature_index]])
249        } else {
250            None
251        }
252    }
253
254    /// Get vocabulary as HashMap for compatibility with visualization
255    pub fn vocabulary_map(&self) -> HashMap<String, usize> {
256        self.count_vectorizer.vocabulary_map()
257    }
258
259    /// Compute IDF values from document frequencies
260    fn compute_idf(&mut self, df: &Array1<f64>, ndocuments: f64) -> Result<()> {
261        let n_features = df.len();
262
263        let mut idf = Array1::zeros(n_features);
264
265        for (i, &df_i) in df.iter().enumerate() {
266            if df_i > 0.0 {
267                if self.smoothidf {
268                    // log((ndocuments + 1) / (df + 1)) + 1
269                    idf[i] = ((ndocuments + 1.0) / (df_i + 1.0)).ln() + 1.0;
270                } else {
271                    // log(ndocuments / df)
272                    idf[i] = (ndocuments / df_i).ln();
273                }
274            } else if self.smoothidf {
275                idf[i] = ((ndocuments + 1.0) / 1.0).ln() + 1.0;
276            } else {
277                // For features that aren't present in the corpus, set IDF to a high value
278                idf[i] = 0.0;
279            }
280        }
281
282        self.idf = Some(idf);
283        Ok(())
284    }
285
286    /// Apply normalization to a document vector
287    fn normalize_vector(&self, vector: &mut Array1<f64>) -> Result<()> {
288        if let Some(ref norm) = self.norm {
289            match norm.as_str() {
290                "l1" => {
291                    let sum = vector.sum();
292                    if sum > 0.0 {
293                        vector.mapv_inplace(|x| x / sum);
294                    }
295                }
296                "l2" => {
297                    let squared_sum: f64 = vector.iter().map(|&x| x * x).sum();
298                    if squared_sum > 0.0 {
299                        let norm = squared_sum.sqrt();
300                        vector.mapv_inplace(|x| x / norm);
301                    }
302                }
303                _ => {
304                    return Err(TextError::InvalidInput(format!(
305                        "Unknown normalization: {norm}"
306                    )))
307                }
308            }
309        }
310
311        Ok(())
312    }
313}
314
315impl Default for TfidfVectorizer {
316    fn default() -> Self {
317        Self::new(false, true, Some("l2".to_string()))
318    }
319}
320
321impl Vectorizer for TfidfVectorizer {
322    fn fit(&mut self, texts: &[&str]) -> Result<()> {
323        if texts.is_empty() {
324            return Err(TextError::InvalidInput(
325                "No texts provided for fitting".into(),
326            ));
327        }
328
329        // First, fit the count vectorizer to build the vocabulary
330        self.count_vectorizer.fit(texts)?;
331
332        let ndocuments = texts.len() as f64;
333        let n_features = self.count_vectorizer.vocabulary_size();
334
335        // Get document frequency for each term
336        let mut df = Array1::zeros(n_features);
337
338        for &text in texts {
339            let tokens = self.count_vectorizer.tokenizer.tokenize(text)?;
340            let mut seen_tokens = HashMap::new();
341
342            // Count each token only once per document
343            for token in tokens {
344                if let Some(idx) = self.count_vectorizer.vocabulary.get_index(&token) {
345                    seen_tokens.insert(idx, true);
346                }
347            }
348
349            // Update document frequencies
350            for idx in seen_tokens.keys() {
351                df[*idx] += 1.0;
352            }
353        }
354
355        // Compute IDF
356        self.compute_idf(&df, ndocuments)?;
357
358        Ok(())
359    }
360
361    fn transform(&self, text: &str) -> Result<Array1<f64>> {
362        if self.idf.is_none() {
363            return Err(TextError::VocabularyError(
364                "IDF values not computed. Call fit() first".into(),
365            ));
366        }
367
368        // Get count vector
369        let mut count_vector = self.count_vectorizer.transform(text)?;
370
371        // Apply TF-IDF transformation
372        let idf = self.idf.as_ref().unwrap();
373        for i in 0..count_vector.len() {
374            count_vector[i] *= idf[i];
375        }
376
377        // Apply normalization if requested
378        self.normalize_vector(&mut count_vector)?;
379
380        Ok(count_vector)
381    }
382
383    fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
384        if self.idf.is_none() {
385            return Err(TextError::VocabularyError(
386                "IDF values not computed. Call fit() first".into(),
387            ));
388        }
389
390        if texts.is_empty() {
391            return Ok(Array2::zeros((0, self.count_vectorizer.vocabulary_size())));
392        }
393
394        // Get count vectors
395        let mut count_matrix = self.count_vectorizer.transform_batch(texts)?;
396
397        // Apply TF-IDF transformation
398        let idf = self.idf.as_ref().unwrap();
399        for mut row in count_matrix.axis_iter_mut(Axis(0)) {
400            for i in 0..row.len() {
401                row[i] *= idf[i];
402            }
403
404            // Apply normalization if requested
405            if let Some(ref norm) = self.norm {
406                match norm.as_str() {
407                    "l1" => {
408                        let sum = row.sum();
409                        if sum > 0.0 {
410                            row.mapv_inplace(|x| x / sum);
411                        }
412                    }
413                    "l2" => {
414                        let squared_sum: f64 = row.iter().map(|&x| x * x).sum();
415                        if squared_sum > 0.0 {
416                            let norm = squared_sum.sqrt();
417                            row.mapv_inplace(|x| x / norm);
418                        }
419                    }
420                    _ => {
421                        return Err(TextError::InvalidInput(format!(
422                            "Unknown normalization: {norm}"
423                        )))
424                    }
425                }
426            }
427        }
428
429        Ok(count_matrix)
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_vectorizer() {
        let mut vectorizer = CountVectorizer::default();
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];

        // Fit the vectorizer
        vectorizer.fit(&corpus).unwrap();

        // Six distinct words appear across the corpus.
        assert_eq!(vectorizer.vocabulary_size(), 6);

        // Transform a document
        let vec = vectorizer.transform(corpus[0]).unwrap();
        assert_eq!(vec.len(), 6);

        // Raw term counts sum to the number of in-vocabulary tokens in the document.
        let vec_sum: f64 = vec.iter().sum();
        assert_eq!(vec_sum, 5.0); // 5 tokens in the first document
    }

    #[test]
    fn test_tfidf_vectorizer() {
        let mut vectorizer = TfidfVectorizer::default();
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];

        // Fit the vectorizer
        vectorizer.fit(&corpus).unwrap();

        // Transform a document
        let vec = vectorizer.transform(corpus[0]).unwrap();
        assert_eq!(vec.len(), 6);

        // The default vectorizer L2-normalizes, so the vector has unit length.
        let norm: f64 = vec.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_binary_vectorizer() {
        let mut vectorizer = CountVectorizer::new(true);
        let corpus = ["this this this is a document", "this is another document"];

        // Fit and transform
        let matrix = vectorizer.fit_transform(&corpus).unwrap();

        // First document should have binary values (all 1.0 or 0.0)
        for val in matrix.row(0).iter() {
            assert!(*val == 0.0 || *val == 1.0);
        }
    }

    #[test]
    fn test_transform_before_fit_is_error() {
        assert!(CountVectorizer::default().transform("some text").is_err());
        assert!(TfidfVectorizer::default().transform("some text").is_err());
        assert!(TfidfVectorizer::default()
            .transform_batch(&["some text"])
            .is_err());
    }

    #[test]
    fn test_fit_rejects_empty_corpus() {
        let empty: [&str; 0] = [];
        assert!(CountVectorizer::default().fit(&empty).is_err());
        assert!(TfidfVectorizer::default().fit(&empty).is_err());
    }

    #[test]
    fn test_empty_batch_has_vocabulary_width() {
        let mut vectorizer = CountVectorizer::default();
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];
        vectorizer.fit(&corpus).unwrap();

        let empty: [&str; 0] = [];
        let matrix = vectorizer.transform_batch(&empty).unwrap();
        assert_eq!(matrix.dim(), (0, vectorizer.vocabulary_size()));
    }

    #[test]
    fn test_unknown_norm_is_error() {
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];
        let mut vectorizer = TfidfVectorizer::new(false, true, Some("l3".to_string()));
        vectorizer.fit(&corpus).unwrap();

        assert!(vectorizer.transform(corpus[0]).is_err());
        assert!(vectorizer.transform_batch(&corpus).is_err());
    }

    #[test]
    fn test_tfidf_batch_matches_single() {
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];
        let norms = vec![None, Some("l1".to_string()), Some("l2".to_string())];

        for norm in norms {
            let mut vectorizer = TfidfVectorizer::new(false, true, norm);
            vectorizer.fit(&corpus).unwrap();

            let matrix = vectorizer.transform_batch(&corpus).unwrap();
            assert_eq!(matrix.dim(), (corpus.len(), vectorizer.vocabulary_size()));

            for (i, &text) in corpus.iter().enumerate() {
                let single = vectorizer.transform(text).unwrap();
                for (a, b) in matrix.row(i).iter().zip(single.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
        }
    }
}