scirs2_text/sparse_vectorize.rs

//! Sparse vectorization for memory-efficient text representation
//!
//! This module provides sparse implementations of text vectorizers
//! that use memory-efficient sparse matrix representations.
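//!
//! # Examples
//!
//! A minimal usage sketch. The import path below assumes these types are
//! re-exported from the crate root, which is not confirmed by this module,
//! so the doctest is marked `ignore`:
//!
//! ```ignore
//! use scirs2_text::{MemoryStats, SparseTfidfVectorizer};
//!
//! let texts = ["the quick brown fox", "the lazy dog", "brown fox jumps"];
//! let mut vectorizer = SparseTfidfVectorizer::new();
//! let matrix = vectorizer.fit_transform(&texts).unwrap();
//!
//! // Rows are L2-normalized TF-IDF vectors stored in CSR form.
//! assert_eq!(matrix.shape().0, 3);
//! let stats = MemoryStats::from_sparse_matrix(&matrix);
//! assert!(stats.sparsity >= 0.0);
//! ```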

use crate::error::{Result, TextError};
use crate::sparse::{CsrMatrix, SparseMatrixBuilder, SparseVector};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

/// Sparse count vectorizer using CSR matrix representation
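///
/// # Examples
///
/// A minimal sketch of fitting on a corpus and transforming a single text;
/// the import path is assumed, so the doctest is marked `ignore`:
///
/// ```ignore
/// use scirs2_text::SparseCountVectorizer;
///
/// let mut vectorizer = SparseCountVectorizer::new(false);
/// vectorizer.fit(&["one fish two fish", "red fish blue fish"]).unwrap();
///
/// // Each text becomes a sparse vector over the fitted vocabulary.
/// let vec = vectorizer.transform("two red fish").unwrap();
/// assert_eq!(vec.size(), vectorizer.vocabulary_size());
/// ```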
pub struct SparseCountVectorizer {
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    vocabulary: Vocabulary,
    binary: bool,
}

impl Clone for SparseCountVectorizer {
    fn clone(&self) -> Self {
        Self {
            tokenizer: self.tokenizer.clone_box(),
            vocabulary: self.vocabulary.clone(),
            binary: self.binary,
        }
    }
}

impl SparseCountVectorizer {
    /// Create a new sparse count vectorizer; if `binary` is true, each
    /// present token is recorded as 1.0 instead of its count
    pub fn new(binary: bool) -> Self {
        Self {
            tokenizer: Box::new(WordTokenizer::default()),
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Create with a custom tokenizer
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>, binary: bool) -> Self {
        Self {
            tokenizer,
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Fit the vectorizer on a corpus
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".into(),
            ));
        }

        self.vocabulary = Vocabulary::new();

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                self.vocabulary.add_token(&token);
            }
        }

        Ok(())
    }

    /// Transform a single text into a sparse vector
    pub fn transform(&self, text: &str) -> Result<SparseVector> {
        let tokens = self.tokenizer.tokenize(text)?;
        let mut counts: HashMap<usize, f64> = HashMap::new();

        for token in tokens {
            if let Some(idx) = self.vocabulary.get_index(&token) {
                *counts.entry(idx).or_insert(0.0) += 1.0;
            }
        }

        // Sort indices for efficient sparse operations
        let mut indices: Vec<usize> = counts.keys().copied().collect();
        indices.sort_unstable();

        let values: Vec<f64> = if self.binary {
            indices.iter().map(|_| 1.0).collect()
        } else {
            indices.iter().map(|&idx| counts[&idx]).collect()
        };

        let sparse_vec = SparseVector::fromindices_values(indices, values, self.vocabulary.len());

        Ok(sparse_vec)
    }

    /// Transform a batch of texts into a sparse matrix
    pub fn transform_batch(&self, texts: &[&str]) -> Result<CsrMatrix> {
        let n_cols = self.vocabulary.len();
        let mut builder = SparseMatrixBuilder::new(n_cols);

        for &text in texts {
            let sparse_vec = self.transform(text)?;
            builder.add_row(sparse_vec)?;
        }

        Ok(builder.build())
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<CsrMatrix> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }

    /// Get vocabulary size
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }

    /// Get the vocabulary
    pub fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }
}

/// Sparse TF-IDF vectorizer
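///
/// # Examples
///
/// A minimal sketch using custom settings (no IDF weighting, no
/// normalization, i.e. plain term counts); the import path is assumed,
/// so the doctest is marked `ignore`:
///
/// ```ignore
/// use scirs2_text::SparseTfidfVectorizer;
///
/// let mut vectorizer = SparseTfidfVectorizer::with_settings(false, None);
/// let matrix = vectorizer
///     .fit_transform(&["the quick brown fox", "the lazy dog"])
///     .unwrap();
/// assert_eq!(matrix.shape(), (2, vectorizer.vocabulary_size()));
/// ```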
#[derive(Clone)]
pub struct SparseTfidfVectorizer {
    count_vectorizer: SparseCountVectorizer,
    idf: Option<Array1<f64>>,
    useidf: bool,
    norm: Option<String>,
}

impl SparseTfidfVectorizer {
    /// Create a new sparse TF-IDF vectorizer with IDF weighting enabled
    /// and L2 normalization
    pub fn new() -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::new(false),
            idf: None,
            useidf: true,
            norm: Some("l2".to_string()),
        }
    }

    /// Create with custom settings: `useidf` toggles IDF weighting, and
    /// `norm` may be `Some("l2")`, `Some("l1")`, or `None` for no normalization
    pub fn with_settings(useidf: bool, norm: Option<String>) -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::new(false),
            idf: None,
            useidf,
            norm,
        }
    }

    /// Create with a custom tokenizer
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::with_tokenizer(tokenizer, false),
            idf: None,
            useidf: true,
            norm: Some("l2".to_string()),
        }
    }

    /// Fit the vectorizer on a corpus
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        self.count_vectorizer.fit(texts)?;

        if self.useidf {
            // Calculate IDF values
            let n_docs = texts.len() as f64;
            let vocab_size = self.count_vectorizer.vocabulary_size();
            let mut doc_freq = vec![0.0; vocab_size];

            // Count document frequencies
            for &text in texts {
                let sparse_vec = self.count_vectorizer.transform(text)?;
                for &idx in sparse_vec.indices() {
                    doc_freq[idx] += 1.0;
                }
            }

            // Calculate IDF values: log(n_docs / df) + 1
            let mut idf_values = Array1::zeros(vocab_size);
            for (idx, &df) in doc_freq.iter().enumerate() {
                if df > 0.0 {
                    idf_values[idx] = (n_docs / df).ln() + 1.0;
                } else {
                    idf_values[idx] = 1.0;
                }
            }

            self.idf = Some(idf_values);
        }

        Ok(())
    }

    /// Transform a single text into a sparse TF-IDF vector
    pub fn transform(&self, text: &str) -> Result<SparseVector> {
        let mut sparse_vec = self.count_vectorizer.transform(text)?;

        // Apply IDF weighting if enabled
        if self.useidf {
            if let Some(ref idf) = self.idf {
                let indices_copy: Vec<usize> = sparse_vec.indices().to_vec();
                let values = sparse_vec.values_mut();
                for (i, &idx) in indices_copy.iter().enumerate() {
                    values[i] *= idf[idx];
                }
            }
        }

        // Apply normalization if specified
        if let Some(ref norm_type) = self.norm {
            match norm_type.as_str() {
                "l2" => {
                    let norm = sparse_vec.norm();
                    if norm > 0.0 {
                        sparse_vec.scale(1.0 / norm);
                    }
                }
                "l1" => {
                    let sum: f64 = sparse_vec.values().iter().map(|x| x.abs()).sum();
                    if sum > 0.0 {
                        sparse_vec.scale(1.0 / sum);
                    }
                }
                _ => {
                    return Err(TextError::InvalidInput(format!(
                        "Unknown normalization type: {norm_type}"
                    )));
                }
            }
        }

        Ok(sparse_vec)
    }

    /// Transform a batch of texts into a sparse TF-IDF matrix
    pub fn transform_batch(&self, texts: &[&str]) -> Result<CsrMatrix> {
        let n_cols = self.count_vectorizer.vocabulary_size();
        let mut builder = SparseMatrixBuilder::new(n_cols);

        for &text in texts {
            let sparse_vec = self.transform(text)?;
            builder.add_row(sparse_vec)?;
        }

        Ok(builder.build())
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<CsrMatrix> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }

    /// Get vocabulary size
    pub fn vocabulary_size(&self) -> usize {
        self.count_vectorizer.vocabulary_size()
    }

    /// Get the vocabulary
    pub fn vocabulary(&self) -> &Vocabulary {
        self.count_vectorizer.vocabulary()
    }

    /// Get the IDF values
    pub fn idf_values(&self) -> Option<&Array1<f64>> {
        self.idf.as_ref()
    }
}

impl Default for SparseTfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Compute cosine similarity between sparse vectors
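///
/// Returns `dot(v1, v2) / (|v1| * |v2|)`, or an error if the dimensions differ.
///
/// # Examples
///
/// A small worked sketch; the import paths are assumed, so the doctest is
/// marked `ignore`:
///
/// ```ignore
/// use scirs2_text::sparse::SparseVector;
/// use scirs2_text::sparse_vectorize::sparse_cosine_similarity;
///
/// // Only index 1 overlaps, so the dot product is 2.0 * 3.0 = 6.0.
/// let v1 = SparseVector::fromindices_values(vec![0, 1], vec![1.0, 2.0], 4);
/// let v2 = SparseVector::fromindices_values(vec![1, 3], vec![3.0, 4.0], 4);
///
/// let sim = sparse_cosine_similarity(&v1, &v2).unwrap();
/// let expected = 6.0 / (5.0_f64.sqrt() * 5.0);
/// assert!((sim - expected).abs() < 1e-10);
/// ```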
#[allow(dead_code)]
pub fn sparse_cosine_similarity(v1: &SparseVector, v2: &SparseVector) -> Result<f64> {
    if v1.size() != v2.size() {
        return Err(TextError::InvalidInput(format!(
            "Vector dimensions don't match: {} vs {}",
            v1.size(),
            v2.size()
        )));
    }

    let dot = v1.dotsparse(v2)?;
    let norm1 = v1.norm();
    let norm2 = v2.norm();

    // By convention, two all-zero vectors are treated as identical (1.0);
    // similarity against a single all-zero vector is 0.0.
    if norm1 == 0.0 || norm2 == 0.0 {
        Ok(if norm1 == norm2 { 1.0 } else { 0.0 })
    } else {
        Ok(dot / (norm1 * norm2))
    }
}

/// Memory usage statistics for sparse representation
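///
/// # Examples
///
/// A minimal sketch comparing a CSR matrix against its dense equivalent;
/// the import path is assumed, so the doctest is marked `ignore`:
///
/// ```ignore
/// use scirs2_text::{MemoryStats, SparseCountVectorizer};
///
/// let mut vectorizer = SparseCountVectorizer::new(false);
/// let matrix = vectorizer
///     .fit_transform(&["sparse data here", "mostly different words there"])
///     .unwrap();
///
/// // Dense size is n_rows * n_cols * size_of::<f64>() bytes.
/// let stats = MemoryStats::from_sparse_matrix(&matrix);
/// assert!(stats.sparsity >= 0.0 && stats.sparsity <= 1.0);
/// ```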
pub struct MemoryStats {
    /// Memory used by sparse representation in bytes
    pub sparse_bytes: usize,
    /// Memory that would be used by dense representation in bytes
    pub dense_bytes: usize,
    /// Compression ratio (dense_bytes / sparse_bytes)
    pub compression_ratio: f64,
    /// Sparsity ratio (number of zeros / total elements)
    pub sparsity: f64,
}

impl MemoryStats {
    /// Calculate memory statistics for a sparse matrix
    pub fn from_sparse_matrix(sparse: &CsrMatrix) -> Self {
        let (n_rows, n_cols) = sparse.shape();
        let dense_bytes = n_rows * n_cols * std::mem::size_of::<f64>();
        let sparse_bytes = sparse.memory_usage();
        let total_elements = n_rows * n_cols;
        let nnz = sparse.nnz();

        Self {
            sparse_bytes,
            dense_bytes,
            compression_ratio: dense_bytes as f64 / sparse_bytes as f64,
            sparsity: 1.0 - (nnz as f64 / total_elements as f64),
        }
    }

    /// Print memory statistics
    pub fn print_stats(&self) {
        println!("Memory Usage Statistics:");
        println!("  Sparse representation: {} bytes", self.sparse_bytes);
        println!("  Dense representation: {} bytes", self.dense_bytes);
        println!("  Compression ratio: {:.2}x", self.compression_ratio);
        println!("  Sparsity: {:.1}%", self.sparsity * 100.0);
        println!(
            "  Memory saved: {:.1}%",
            (1.0 - 1.0 / self.compression_ratio) * 100.0
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_count_vectorizer() {
        // Use larger, sparser data to ensure compression benefits
        let texts = vec![
            "this is a test document with some unique words",
            "this is another test document with different vocabulary",
            "yet another example document with more text content",
            "completely different text with various other terms",
            "final document in the test set with distinct words",
        ];

        let mut vectorizer = SparseCountVectorizer::new(false);
        let sparse_matrix = vectorizer.fit_transform(&texts).unwrap();

        assert_eq!(sparse_matrix.shape().0, 5); // 5 documents
        assert!(sparse_matrix.nnz() > 0);

        // Check memory efficiency - with larger vocabulary, sparse should be more efficient
        let stats = MemoryStats::from_sparse_matrix(&sparse_matrix);
        // For small test data, just verify it's calculated properly
        assert!(stats.compression_ratio > 0.0);
        assert!(stats.sparsity >= 0.0);
    }

    #[test]
    fn test_sparse_tfidf_vectorizer() {
        let texts = vec!["the quick brown fox", "the lazy dog", "brown fox jumps"];

        let mut vectorizer = SparseTfidfVectorizer::new();
        let sparse_matrix = vectorizer.fit_transform(&texts).unwrap();

        assert_eq!(sparse_matrix.shape().0, 3);

        // Verify TF-IDF properties
        let first_doc = sparse_matrix.get_row(0).unwrap();
        assert!(first_doc.norm() > 0.0);

        // With L2 normalization, the norm should be approximately 1
        assert!((first_doc.norm() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_cosine_similarity() {
        let v1 = SparseVector::fromindices_values(vec![0, 2, 3], vec![1.0, 2.0, 3.0], 5);

        let v2 = SparseVector::fromindices_values(vec![1, 2, 4], vec![1.0, 2.0, 1.0], 5);

        let similarity = sparse_cosine_similarity(&v1, &v2).unwrap();

        // Only index 2 overlaps with value 2.0 in both
        // v1 dot v2 = 2.0 * 2.0 = 4.0
        // |v1| = sqrt(1 + 4 + 9) = sqrt(14)
        // |v2| = sqrt(1 + 4 + 1) = sqrt(6)
        // cos = 4 / (sqrt(14) * sqrt(6))
        let expected = 4.0 / (14.0_f64.sqrt() * 6.0_f64.sqrt());
        assert!((similarity - expected).abs() < 1e-10);
    }

    #[test]
    fn test_memory_efficiency_large() {
        // Create a large corpus with sparse content
        let texts: Vec<String> = (0..100)
            .map(|i| {
                let word_idx = i % 10;
                format!("document {i} contains word{word_idx}")
            })
            .collect();

        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_ref()).collect();

        let mut vectorizer = SparseCountVectorizer::new(false);
        let sparse_matrix = vectorizer.fit_transform(&text_refs).unwrap();

        let stats = MemoryStats::from_sparse_matrix(&sparse_matrix);
        stats.print_stats();

        // Should achieve significant compression for sparse data
        assert!(stats.compression_ratio > 5.0);
        assert!(stats.sparsity > 0.8);
    }
}