//! Sparse text vectorizers: count and TF-IDF representations backed by CSR storage.

use crate::error::{Result, TextError};
use crate::sparse::{CsrMatrix, SparseMatrixBuilder, SparseVector};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

/// A count vectorizer that produces sparse term-frequency vectors.
pub struct SparseCountVectorizer {
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    vocabulary: Vocabulary,
    binary: bool,
}

// Manual `Clone` impl: the boxed tokenizer trait object cannot derive `Clone`,
// so we delegate to `clone_box`.
impl Clone for SparseCountVectorizer {
    fn clone(&self) -> Self {
        Self {
            tokenizer: self.tokenizer.clone_box(),
            vocabulary: self.vocabulary.clone(),
            binary: self.binary,
        }
    }
}

impl SparseCountVectorizer {
    /// Creates a vectorizer with the default word tokenizer.
    ///
    /// When `binary` is true, transformed vectors record token presence (1.0)
    /// instead of raw counts.
    pub fn new(binary: bool) -> Self {
        Self {
            tokenizer: Box::new(WordTokenizer::default()),
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Creates a vectorizer with a custom tokenizer.
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>, binary: bool) -> Self {
        Self {
            tokenizer,
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Builds the vocabulary from the given texts, replacing any previous fit.
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".into(),
            ));
        }

        self.vocabulary = Vocabulary::new();

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                self.vocabulary.add_token(&token);
            }
        }

        Ok(())
    }

    /// Transforms a single text into a sparse count (or binary) vector.
    /// Tokens not in the fitted vocabulary are ignored.
    pub fn transform(&self, text: &str) -> Result<SparseVector> {
        let tokens = self.tokenizer.tokenize(text)?;
        let mut counts: HashMap<usize, f64> = HashMap::new();

        for token in tokens {
            if let Some(idx) = self.vocabulary.get_index(&token) {
                *counts.entry(idx).or_insert(0.0) += 1.0;
            }
        }

        // SparseVector expects indices in ascending order.
        let mut indices: Vec<usize> = counts.keys().copied().collect();
        indices.sort_unstable();

        let values: Vec<f64> = if self.binary {
            indices.iter().map(|_| 1.0).collect()
        } else {
            indices.iter().map(|&idx| counts[&idx]).collect()
        };

        let sparse_vec = SparseVector::from_indices_values(indices, values, self.vocabulary.len());

        Ok(sparse_vec)
    }

    /// Transforms a batch of texts into a CSR matrix, one row per text.
    pub fn transform_batch(&self, texts: &[&str]) -> Result<CsrMatrix> {
        let n_cols = self.vocabulary.len();
        let mut builder = SparseMatrixBuilder::new(n_cols);

        for &text in texts {
            let sparse_vec = self.transform(text)?;
            builder.add_row(sparse_vec)?;
        }

        Ok(builder.build())
    }

    /// Fits the vocabulary and transforms the texts in one call.
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<CsrMatrix> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }

    /// Returns the number of terms in the fitted vocabulary.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }

    /// Returns a reference to the fitted vocabulary.
    pub fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }
}

/// A TF-IDF vectorizer layered on top of `SparseCountVectorizer`.
#[derive(Clone)]
pub struct SparseTfidfVectorizer {
    count_vectorizer: SparseCountVectorizer,
    idf: Option<Array1<f64>>,
    use_idf: bool,
    norm: Option<String>,
}

impl SparseTfidfVectorizer {
    /// Creates a TF-IDF vectorizer with IDF weighting and L2 normalization.
    pub fn new() -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::new(false),
            idf: None,
            use_idf: true,
            norm: Some("l2".to_string()),
        }
    }

    /// Creates a vectorizer with explicit settings. `norm` accepts `"l1"`,
    /// `"l2"`, or `None` for no normalization.
    pub fn with_settings(use_idf: bool, norm: Option<String>) -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::new(false),
            idf: None,
            use_idf,
            norm,
        }
    }

    /// Creates a vectorizer with a custom tokenizer and default settings.
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        Self {
            count_vectorizer: SparseCountVectorizer::with_tokenizer(tokenizer, false),
            idf: None,
            use_idf: true,
            norm: Some("l2".to_string()),
        }
    }

    /// Fits the vocabulary and, if enabled, computes smoothed IDF weights:
    /// `idf(t) = ln(n_docs / df(t)) + 1`.
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        self.count_vectorizer.fit(texts)?;

        if self.use_idf {
            let n_docs = texts.len() as f64;
            let vocab_size = self.count_vectorizer.vocabulary_size();
            let mut doc_freq = vec![0.0; vocab_size];

            // Count, for each term, the number of documents containing it.
            for &text in texts {
                let sparse_vec = self.count_vectorizer.transform(text)?;
                for &idx in sparse_vec.indices() {
                    doc_freq[idx] += 1.0;
                }
            }

            let mut idf_values = Array1::zeros(vocab_size);
            for (idx, &df) in doc_freq.iter().enumerate() {
                if df > 0.0 {
                    idf_values[idx] = (n_docs / df).ln() + 1.0;
                } else {
                    // Every fitted term has df >= 1, but keep a safe default.
                    idf_values[idx] = 1.0;
                }
            }

            self.idf = Some(idf_values);
        }

        Ok(())
    }

    /// Transforms a text into a TF-IDF vector, applying the fitted IDF
    /// weights and the configured normalization.
    pub fn transform(&self, text: &str) -> Result<SparseVector> {
        let mut sparse_vec = self.count_vectorizer.transform(text)?;

        if self.use_idf {
            if let Some(ref idf) = self.idf {
                // Copy the indices first so the mutable borrow of the values
                // does not alias the immutable borrow of the indices.
                let indices_copy: Vec<usize> = sparse_vec.indices().to_vec();
                let values = sparse_vec.values_mut();
                for (i, &idx) in indices_copy.iter().enumerate() {
                    values[i] *= idf[idx];
                }
            }
        }

        if let Some(ref norm_type) = self.norm {
            match norm_type.as_str() {
                "l2" => {
                    let norm = sparse_vec.norm();
                    if norm > 0.0 {
                        sparse_vec.scale(1.0 / norm);
                    }
                }
                "l1" => {
                    let sum: f64 = sparse_vec.values().iter().map(|x| x.abs()).sum();
                    if sum > 0.0 {
                        sparse_vec.scale(1.0 / sum);
                    }
                }
                _ => {
                    return Err(TextError::InvalidInput(format!(
                        "Unknown normalization type: {norm_type}"
                    )));
                }
            }
        }

        Ok(sparse_vec)
    }

    /// Transforms a batch of texts into a CSR matrix, one row per text.
    pub fn transform_batch(&self, texts: &[&str]) -> Result<CsrMatrix> {
        let n_cols = self.count_vectorizer.vocabulary_size();
        let mut builder = SparseMatrixBuilder::new(n_cols);

        for &text in texts {
            let sparse_vec = self.transform(text)?;
            builder.add_row(sparse_vec)?;
        }

        Ok(builder.build())
    }

    /// Fits and transforms in one call.
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<CsrMatrix> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }

    /// Returns the number of terms in the fitted vocabulary.
    pub fn vocabulary_size(&self) -> usize {
        self.count_vectorizer.vocabulary_size()
    }

    /// Returns a reference to the fitted vocabulary.
    pub fn vocabulary(&self) -> &Vocabulary {
        self.count_vectorizer.vocabulary()
    }

    /// Returns the fitted IDF weights, if IDF is enabled and `fit` has run.
    pub fn idf_values(&self) -> Option<&Array1<f64>> {
        self.idf.as_ref()
    }
}

impl Default for SparseTfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Computes the cosine similarity between two sparse vectors of equal
/// dimension. Two all-zero vectors are treated as identical (similarity 1.0);
/// a single zero vector yields 0.0.
#[allow(dead_code)]
pub fn sparse_cosine_similarity(v1: &SparseVector, v2: &SparseVector) -> Result<f64> {
    if v1.size() != v2.size() {
        return Err(TextError::InvalidInput(format!(
            "Vector dimensions don't match: {} vs {}",
            v1.size(),
            v2.size()
        )));
    }

    let dot = v1.dot_sparse(v2)?;
    let norm1 = v1.norm();
    let norm2 = v2.norm();

    if norm1 == 0.0 || norm2 == 0.0 {
        Ok(if norm1 == norm2 { 1.0 } else { 0.0 })
    } else {
        Ok(dot / (norm1 * norm2))
    }
}

/// Memory-usage statistics comparing sparse and dense representations.
pub struct MemoryStats {
    /// Bytes used by the sparse (CSR) representation.
    pub sparse_bytes: usize,
    /// Bytes an equivalent dense `f64` matrix would use.
    pub dense_bytes: usize,
    /// Ratio of dense to sparse size (higher means better compression).
    pub compression_ratio: f64,
    /// Fraction of matrix entries that are zero, in `[0, 1]`.
    pub sparsity: f64,
}

impl MemoryStats {
    /// Computes statistics for a CSR matrix against a hypothetical dense layout.
    pub fn from_sparse_matrix(sparse: &CsrMatrix) -> Self {
        let (n_rows, n_cols) = sparse.shape();
        let dense_bytes = n_rows * n_cols * std::mem::size_of::<f64>();
        let sparse_bytes = sparse.memory_usage();
        let total_elements = n_rows * n_cols;
        let nnz = sparse.nnz();

        Self {
            sparse_bytes,
            dense_bytes,
            compression_ratio: dense_bytes as f64 / sparse_bytes as f64,
            sparsity: 1.0 - (nnz as f64 / total_elements as f64),
        }
    }

    /// Prints the statistics to stdout.
    pub fn print_stats(&self) {
        println!("Memory Usage Statistics:");
        println!("  Sparse representation: {} bytes", self.sparse_bytes);
        println!("  Dense representation: {} bytes", self.dense_bytes);
        println!("  Compression ratio: {:.2}x", self.compression_ratio);
        println!("  Sparsity: {:.1}%", self.sparsity * 100.0);
        println!(
            "  Memory saved: {:.1}%",
            (1.0 - 1.0 / self.compression_ratio) * 100.0
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_count_vectorizer() {
        let texts = vec![
            "this is a test document with some unique words",
            "this is another test document with different vocabulary",
            "yet another example document with more text content",
            "completely different text with various other terms",
            "final document in the test set with distinct words",
        ];

        let mut vectorizer = SparseCountVectorizer::new(false);
        let sparse_matrix = vectorizer.fit_transform(&texts).unwrap();

        assert_eq!(sparse_matrix.shape().0, 5);
        assert!(sparse_matrix.nnz() > 0);

        let stats = MemoryStats::from_sparse_matrix(&sparse_matrix);
        assert!(stats.compression_ratio > 0.0);
        assert!(stats.sparsity >= 0.0);
    }

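    // A minimal sketch exercising the `binary` flag: with `binary = true`,
    // `transform` emits 1.0 for every present token instead of its count
    // (assuming the default `WordTokenizer` splits on whitespace).
    #[test]
    fn test_sparse_count_vectorizer_binary() {
        let texts = vec!["repeat repeat repeat once"];

        let mut vectorizer = SparseCountVectorizer::new(true);
        let sparse_matrix = vectorizer.fit_transform(&texts).unwrap();

        let row = sparse_matrix.get_row(0).unwrap();
        assert!(!row.values().is_empty());
        assert!(row.values().iter().all(|&v| v == 1.0));
    }
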
    #[test]
    fn test_sparse_tfidf_vectorizer() {
        let texts = vec!["the quick brown fox", "the lazy dog", "brown fox jumps"];

        let mut vectorizer = SparseTfidfVectorizer::new();
        let sparse_matrix = vectorizer.fit_transform(&texts).unwrap();

        assert_eq!(sparse_matrix.shape().0, 3);

        let first_doc = sparse_matrix.get_row(0).unwrap();
        assert!(first_doc.norm() > 0.0);

        // The default L2 normalization should produce unit-norm rows.
        assert!((first_doc.norm() - 1.0).abs() < 1e-6);
    }

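    // Sketches for the configurable settings. The IDF check relies on the
    // smoothed formula used in `fit`, idf = ln(n_docs / df) + 1, so a term
    // present in every document gets an IDF of exactly 1.0. It also assumes
    // `Vocabulary::get_index` accepts a `&str`, as its use in `transform`
    // suggests.
    #[test]
    fn test_idf_values() {
        let texts = vec!["shared alpha", "shared beta", "shared gamma"];

        let mut vectorizer = SparseTfidfVectorizer::new();
        vectorizer.fit(&texts).unwrap();

        let idf = vectorizer.idf_values().unwrap();
        assert_eq!(idf.len(), vectorizer.vocabulary_size());

        // "shared" occurs in all 3 documents: ln(3/3) + 1 = 1.0.
        let shared_idx = vectorizer.vocabulary().get_index("shared").unwrap();
        assert!((idf[shared_idx] - 1.0).abs() < 1e-12);
    }

    // With `norm = Some("l1")`, the absolute values of a transformed vector
    // should sum to 1.0; an unrecognized norm name should surface as an error.
    #[test]
    fn test_norm_settings() {
        let texts = vec!["one two two", "three four"];

        let mut l1 = SparseTfidfVectorizer::with_settings(true, Some("l1".to_string()));
        let matrix = l1.fit_transform(&texts).unwrap();
        let row = matrix.get_row(0).unwrap();
        let sum: f64 = row.values().iter().map(|x| x.abs()).sum();
        assert!((sum - 1.0).abs() < 1e-6);

        let mut bad = SparseTfidfVectorizer::with_settings(false, Some("max".to_string()));
        bad.fit(&texts).unwrap();
        assert!(bad.transform("one two").is_err());
    }
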
    #[test]
    fn test_sparse_cosine_similarity() {
        let v1 = SparseVector::from_indices_values(vec![0, 2, 3], vec![1.0, 2.0, 3.0], 5);
        let v2 = SparseVector::from_indices_values(vec![1, 2, 4], vec![1.0, 2.0, 1.0], 5);

        let similarity = sparse_cosine_similarity(&v1, &v2).unwrap();

        // Only index 2 overlaps: dot = 2.0 * 2.0 = 4.0, with
        // ||v1|| = sqrt(14) and ||v2|| = sqrt(6).
        let expected = 4.0 / (14.0_f64.sqrt() * 6.0_f64.sqrt());
        assert!((similarity - expected).abs() < 1e-10);
    }

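    // Mismatched dimensions should be rejected rather than silently truncated.
    #[test]
    fn test_sparse_cosine_similarity_dimension_mismatch() {
        let v1 = SparseVector::from_indices_values(vec![0], vec![1.0], 3);
        let v2 = SparseVector::from_indices_values(vec![0], vec![1.0], 5);

        assert!(sparse_cosine_similarity(&v1, &v2).is_err());
    }
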
    #[test]
    fn test_memory_efficiency_large() {
        // 100 short documents drawn from a small shared vocabulary, so the
        // resulting matrix is highly sparse.
        let texts: Vec<String> = (0..100)
            .map(|i| {
                let word_idx = i % 10;
                format!("document {i} contains word{word_idx}")
            })
            .collect();

        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_ref()).collect();

        let mut vectorizer = SparseCountVectorizer::new(false);
        let sparse_matrix = vectorizer.fit_transform(&text_refs).unwrap();

        let stats = MemoryStats::from_sparse_matrix(&sparse_matrix);
        stats.print_stats();

        assert!(stats.compression_ratio > 5.0);
        assert!(stats.sparsity > 0.8);
    }
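
    // Edge-case sketch: fitting on an empty slice is an error, and tokens
    // never seen during fitting are dropped at transform time (again assuming
    // whitespace tokenization).
    #[test]
    fn test_fit_errors_and_unseen_tokens() {
        let mut vectorizer = SparseCountVectorizer::new(false);
        assert!(vectorizer.fit(&[]).is_err());

        vectorizer.fit(&["alpha beta"]).unwrap();
        let sv = vectorizer.transform("alpha gamma").unwrap();
        // Only "alpha" is in the vocabulary; "gamma" is ignored.
        assert_eq!(sv.values().len(), 1);
    }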
}