use crate::error::{Result, TextError};
use crate::tokenize::{NgramTokenizer, Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;

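/// A count vectorizer with configurable n-gram range, document-frequency
/// filtering (`min_df`/`max_df`), an optional feature cap, binary output,
/// and optional lowercasing.
///
/// A minimal usage sketch (marked `ignore` because the `use` path depends on
/// how the crate exposes this module):
///
/// ```ignore
/// let mut vectorizer = EnhancedCountVectorizer::new()
///     .set_ngram_range((1, 2))
///     .unwrap()
///     .set_max_features(Some(1_000));
/// vectorizer.fit(&["first document", "second document"]).unwrap();
/// let counts = vectorizer.transform("first document").unwrap();
/// ```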
pub struct EnhancedCountVectorizer {
    vocabulary: Vocabulary,
    /// When true, counts are clamped to 0/1.
    binary: bool,
    /// Inclusive (min_n, max_n) range of n-gram sizes.
    ngram_range: (usize, usize),
    /// Optional cap on vocabulary size, keeping tokens with the highest document frequency.
    max_features: Option<usize>,
    /// Minimum document frequency, as a fraction of documents, for a token to be kept.
    min_df: f64,
    /// Maximum document frequency, as a fraction of documents, for a token to be kept.
    max_df: f64,
    /// When true, texts are lowercased before tokenization.
    lowercase: bool,
}

impl EnhancedCountVectorizer {
    pub fn new() -> Self {
        Self {
            vocabulary: Vocabulary::new(),
            binary: false,
            ngram_range: (1, 1),
            max_features: None,
            min_df: 0.0,
            max_df: 1.0,
            lowercase: true,
        }
    }

    pub fn set_binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    pub fn set_ngram_range(mut self, range: (usize, usize)) -> Result<Self> {
        if range.0 == 0 || range.1 < range.0 {
            return Err(TextError::InvalidInput(
                "Invalid n-gram range. Must have min_n > 0 and max_n >= min_n".to_string(),
            ));
        }
        self.ngram_range = range;
        Ok(self)
    }

    pub fn set_max_features(mut self, max_features: Option<usize>) -> Self {
        self.max_features = max_features;
        self
    }

    pub fn set_min_df(mut self, min_df: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&min_df) {
            return Err(TextError::InvalidInput(
                "min_df must be between 0.0 and 1.0".to_string(),
            ));
        }
        self.min_df = min_df;
        Ok(self)
    }

    pub fn set_max_df(mut self, max_df: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&max_df) {
            return Err(TextError::InvalidInput(
                "max_df must be between 0.0 and 1.0".to_string(),
            ));
        }
        self.max_df = max_df;
        Ok(self)
    }

    pub fn set_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    pub fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

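    /// Builds the vocabulary from `texts`, applying the `min_df`/`max_df`
    /// document-frequency bounds and the optional `max_features` cap.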
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".to_string(),
            ));
        }

        self.vocabulary = Vocabulary::new();

        // Document frequency: number of documents each token appears in.
        let mut doc_frequencies: HashMap<String, usize> = HashMap::new();
        let total_docs = texts.len();

        for text in texts {
            let mut seen_in_doc: HashMap<String, bool> = HashMap::new();

            let all_tokens = self.extract_ngrams(text)?;

            for token in all_tokens {
                // Count each token at most once per document.
                if !seen_in_doc.contains_key(&token) {
                    *doc_frequencies.entry(token.clone()).or_insert(0) += 1;
                    seen_in_doc.insert(token.clone(), true);
                }

                self.vocabulary.add_token(&token);
            }
        }

        // Convert the fractional document-frequency bounds into absolute counts.
        let min_count = (self.min_df * total_docs as f64).ceil() as usize;
        let max_count = (self.max_df * total_docs as f64).floor() as usize;

        let mut filtered_tokens: Vec<(String, usize)> = doc_frequencies
            .into_iter()
            .filter(|(_, count)| *count >= min_count && *count <= max_count)
            .collect();

        // Sort by descending document frequency so max_features trims the rarest tokens.
        filtered_tokens.sort_by(|a, b| b.1.cmp(&a.1));

        if let Some(max_features) = self.max_features {
            filtered_tokens.truncate(max_features);
        }

        // Rebuild the vocabulary from the filtered token set.
        self.vocabulary = Vocabulary::with_maxsize(self.max_features.unwrap_or(usize::MAX));
        for (token, _) in filtered_tokens {
            self.vocabulary.add_token(&token);
        }

        Ok(())
    }

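    /// Tokenizes `text` into unigrams or n-grams according to `ngram_range`,
    /// lowercasing first when `lowercase` is enabled.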
    fn extract_ngrams(&self, text: &str) -> Result<Vec<String>> {
        let text = if self.lowercase {
            text.to_lowercase()
        } else {
            text.to_string()
        };

        let all_ngrams = if self.ngram_range == (1, 1) {
            let tokenizer = WordTokenizer::new(false);
            tokenizer.tokenize(&text)?
        } else {
            let ngram_tokenizer =
                NgramTokenizer::with_range(self.ngram_range.0, self.ngram_range.1)?;
            ngram_tokenizer.tokenize(&text)?
        };

        Ok(all_ngrams)
    }

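    /// Transforms a single text into a count vector over the fitted
    /// vocabulary. Counts are clamped to 0/1 when `binary` is set.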
    pub fn transform(&self, text: &str) -> Result<Array1<f64>> {
        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "Vocabulary is empty. Call fit() first".to_string(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let mut vector = Array1::zeros(vocab_size);

        let tokens = self.extract_ngrams(text)?;

        for token in tokens {
            // Tokens outside the fitted vocabulary are ignored.
            if let Some(idx) = self.vocabulary.get_index(&token) {
                vector[idx] += 1.0;
            }
        }

        if self.binary {
            for val in vector.iter_mut() {
                if *val > 0.0 {
                    *val = 1.0;
                }
            }
        }

        Ok(vector)
    }

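    /// Transforms a batch of texts into an `(n_samples, vocab_size)` count matrix.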
    pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "Vocabulary is empty. Call fit() first".to_string(),
            ));
        }

        let n_samples = texts.len();
        let vocab_size = self.vocabulary.len();
        let mut matrix = Array2::zeros((n_samples, vocab_size));

        for (i, text) in texts.iter().enumerate() {
            let vector = self.transform(text)?;
            matrix.row_mut(i).assign(&vector);
        }

        Ok(matrix)
    }

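    /// Fits the vocabulary on `texts` and returns their count matrix in one call.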
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }
}

impl Default for EnhancedCountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

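/// A TF-IDF vectorizer built on top of [`EnhancedCountVectorizer`], with
/// optional IDF weighting, IDF smoothing, sublinear TF scaling, and
/// l1/l2 normalization of the output vectors.
///
/// A minimal usage sketch (marked `ignore` because the `use` path depends on
/// how the crate exposes this module):
///
/// ```ignore
/// let mut vectorizer = EnhancedTfidfVectorizer::new()
///     .set_sublinear_tf(true)
///     .set_norm(Some("l2".to_string()))
///     .unwrap();
/// let matrix = vectorizer
///     .fit_transform(&["first document", "second document"])
///     .unwrap();
/// ```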
pub struct EnhancedTfidfVectorizer {
    count_vectorizer: EnhancedCountVectorizer,
    /// When true, term frequencies are weighted by IDF.
    useidf: bool,
    /// When true, document frequencies are smoothed as if one extra document
    /// containing every term had been seen.
    smoothidf: bool,
    /// When true, term frequencies are replaced by 1 + ln(tf).
    sublinear_tf: bool,
    /// Normalization applied to output vectors: "l1", "l2", or None.
    norm: Option<String>,
    /// Per-feature IDF weights, computed by `fit()`.
    idf_: Option<Array1<f64>>,
}

impl EnhancedTfidfVectorizer {
    pub fn new() -> Self {
        Self {
            count_vectorizer: EnhancedCountVectorizer::new(),
            useidf: true,
            smoothidf: true,
            sublinear_tf: false,
            norm: Some("l2".to_string()),
            idf_: None,
        }
    }

    pub fn set_use_idf(mut self, use_idf: bool) -> Self {
        self.useidf = use_idf;
        self
    }

    pub fn set_smooth_idf(mut self, smooth_idf: bool) -> Self {
        self.smoothidf = smooth_idf;
        self
    }

    pub fn set_sublinear_tf(mut self, sublinear_tf: bool) -> Self {
        self.sublinear_tf = sublinear_tf;
        self
    }

    pub fn set_norm(mut self, norm: Option<String>) -> Result<Self> {
        if let Some(ref n) = norm {
            if n != "l1" && n != "l2" {
                return Err(TextError::InvalidInput(
                    "Norm must be 'l1', 'l2', or None".to_string(),
                ));
            }
        }
        self.norm = norm;
        Ok(self)
    }

    pub fn set_ngram_range(mut self, range: (usize, usize)) -> Result<Self> {
        self.count_vectorizer = self.count_vectorizer.set_ngram_range(range)?;
        Ok(self)
    }

    pub fn set_max_features(mut self, max_features: Option<usize>) -> Self {
        self.count_vectorizer = self.count_vectorizer.set_max_features(max_features);
        self
    }

    pub fn vocabulary(&self) -> &Vocabulary {
        self.count_vectorizer.vocabulary()
    }

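    /// Fits the underlying count vectorizer and, when IDF weighting is
    /// enabled, computes the per-feature IDF weights.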
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        self.count_vectorizer.fit(texts)?;

        if self.useidf {
            self.calculate_idf(texts)?;
        }

        Ok(())
    }

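    /// Computes IDF weights from document frequencies. With smoothing,
    /// idf = ln((1 + n_samples) / (1 + df)) + 1; without smoothing,
    /// idf = ln(n_samples / max(df, 1)) + 1.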
    fn calculate_idf(&mut self, texts: &[&str]) -> Result<()> {
        let vocab_size = self.count_vectorizer.vocabulary().len();
        let mut df: Array1<f64> = Array1::zeros(vocab_size);
        let n_samples = texts.len() as f64;

        // Count, for each feature, the number of documents it appears in.
        for text in texts {
            let count_vec = self.count_vectorizer.transform(text)?;
            for (idx, &count) in count_vec.iter().enumerate() {
                if count > 0.0 {
                    df[idx] += 1.0;
                }
            }
        }

        let mut idf = Array1::zeros(vocab_size);
        for (idx, &doc_freq) in df.iter().enumerate() {
            if self.smoothidf {
                idf[idx] = (1.0 + n_samples) / (1.0 + doc_freq);
            } else {
                idf[idx] = n_samples / doc_freq.max(1.0);
            }
            idf[idx] = idf[idx].ln() + 1.0;
        }

        self.idf_ = Some(idf);
        Ok(())
    }

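    /// Transforms a single text into a TF-IDF vector: raw counts, optional
    /// sublinear TF scaling, optional IDF weighting, then optional l1/l2
    /// normalization.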
    pub fn transform(&self, text: &str) -> Result<Array1<f64>> {
        let mut vector = self.count_vectorizer.transform(text)?;

        // Sublinear TF: replace each nonzero count c with 1 + ln(c).
        if self.sublinear_tf {
            for val in vector.iter_mut() {
                if *val > 0.0 {
                    *val = 1.0 + (*val).ln();
                }
            }
        }

        if self.useidf {
            if let Some(ref idf) = self.idf_ {
                vector *= idf;
            } else {
                return Err(TextError::VocabularyError(
                    "IDF weights not calculated. Call fit() first".to_string(),
                ));
            }
        }

        if let Some(ref norm) = self.norm {
            match norm.as_str() {
                "l1" => {
                    let norm_val = vector.iter().map(|x| x.abs()).sum::<f64>();
                    if norm_val > 0.0 {
                        vector /= norm_val;
                    }
                }
                "l2" => {
                    let norm_val = vector.dot(&vector).sqrt();
                    if norm_val > 0.0 {
                        vector /= norm_val;
                    }
                }
                _ => {}
            }
        }

        Ok(vector)
    }

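    /// Transforms a batch of texts into an `(n_samples, vocab_size)` TF-IDF matrix.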
    pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        let n_samples = texts.len();
        let vocab_size = self.count_vectorizer.vocabulary().len();
        let mut matrix = Array2::zeros((n_samples, vocab_size));

        for (i, text) in texts.iter().enumerate() {
            let vector = self.transform(text)?;
            matrix.row_mut(i).assign(&vector);
        }

        Ok(matrix)
    }

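    /// Fits on `texts` and returns their TF-IDF matrix in one call.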
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }
}

impl Default for EnhancedTfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enhanced_count_vectorizer_unigrams() {
        let mut vectorizer = EnhancedCountVectorizer::new();

        let documents = vec![
            "this is a test",
            "this is another test",
            "something different here",
        ];

        vectorizer.fit(&documents).unwrap();

        let vector = vectorizer.transform("this is a test").unwrap();
        assert!(!vector.is_empty());
    }

    #[test]
    fn test_enhanced_count_vectorizer_ngrams() {
        let mut vectorizer = EnhancedCountVectorizer::new()
            .set_ngram_range((1, 2))
            .unwrap();

        let documents = vec!["hello world", "hello there", "world peace"];

        vectorizer.fit(&documents).unwrap();

        // The vocabulary should contain both unigrams and bigrams.
        let vocab = vectorizer.vocabulary();
        assert!(vocab.len() > 3);
    }

    #[test]
    fn test_enhanced_tfidf_vectorizer() {
        let mut vectorizer = EnhancedTfidfVectorizer::new()
            .set_smooth_idf(true)
            .set_norm(Some("l2".to_string()))
            .unwrap();

        let documents = vec![
            "this is a test",
            "this is another test",
            "something different here",
        ];

        vectorizer.fit(&documents).unwrap();

        let vector = vectorizer.transform("this is a test").unwrap();

        // With l2 normalization the transformed vector should have unit length.
        let norm = vector.dot(&vector).sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_features() {
        let mut vectorizer = EnhancedCountVectorizer::new().set_max_features(Some(5));

        let documents = vec![
            "one two three four five six seven eight",
            "one two three four five six seven eight nine ten",
        ];

        vectorizer.fit(&documents).unwrap();

        assert_eq!(vectorizer.vocabulary().len(), 5);
    }

    #[test]
    fn test_document_frequency_filtering() {
        // min_df = 0.5 over 3 documents keeps only tokens appearing in at least 2 of them.
        let mut vectorizer = EnhancedCountVectorizer::new().set_min_df(0.5).unwrap();

        let documents = vec![
            "common word rare",
            "common word unique",
            "common another distinct",
        ];

        vectorizer.fit(&documents).unwrap();

        let vocab = vectorizer.vocabulary();
        assert!(vocab.contains("common"));
        assert!(!vocab.contains("rare"));
        assert!(!vocab.contains("unique"));
    }
}