use ahash::AHasher;
use regex::Regex;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use crate::error::{Result, TransformError};

/// Converts a collection of text documents into a matrix of token counts.
pub struct CountVectorizer {
    vocabulary: HashMap<String, usize>,
    feature_names: Vec<String>,
    max_features: Option<usize>,
    min_df: f64,
    max_df: f64,
    lowercase: bool,
    token_pattern: Regex,
    stop_words: HashSet<String>,
    fitted: bool,
}

impl CountVectorizer {
    pub fn new() -> Self {
        CountVectorizer {
            vocabulary: HashMap::new(),
            feature_names: Vec::new(),
            max_features: None,
            min_df: 1.0,
            max_df: 1.0,
            lowercase: true,
            token_pattern: Regex::new(r"\b\w+\b").unwrap(),
            stop_words: HashSet::new(),
            fitted: false,
        }
    }

    #[allow(dead_code)]
    pub fn with_max_features(mut self, max_features: usize) -> Self {
        self.max_features = Some(max_features);
        self
    }

    #[allow(dead_code)]
    pub fn with_min_df(mut self, min_df: f64) -> Self {
        self.min_df = min_df;
        self
    }

    #[allow(dead_code)]
    pub fn with_max_df(mut self, max_df: f64) -> Self {
        self.max_df = max_df;
        self
    }

    #[allow(dead_code)]
    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    #[allow(dead_code)]
    pub fn with_token_pattern(mut self, pattern: &str) -> Result<Self> {
        self.token_pattern = Regex::new(pattern)
            .map_err(|e| TransformError::InvalidInput(format!("Invalid regex pattern: {e}")))?;
        Ok(self)
    }

    #[allow(dead_code)]
    pub fn with_stop_words(mut self, stop_words: Vec<String>) -> Self {
        self.stop_words = stop_words.into_iter().collect();
        self
    }

    fn tokenize(&self, doc: &str) -> Vec<String> {
        let text = if self.lowercase {
            doc.to_lowercase()
        } else {
            doc.to_string()
        };

        self.token_pattern
            .find_iter(&text)
            .map(|m| m.as_str().to_string())
            .filter(|token| !self.stop_words.contains(token))
            .collect()
    }

    pub fn fit(&mut self, documents: &[String]) -> Result<()> {
        if documents.is_empty() {
            return Err(TransformError::InvalidInput(
                "Empty document collection".into(),
            ));
        }

        let mut term_doc_freq: HashMap<String, usize> = HashMap::new();
        let n_docs = documents.len();

        for doc in documents {
            let tokens: HashSet<String> = self.tokenize(doc).into_iter().collect();
            for token in tokens {
                *term_doc_freq.entry(token).or_insert(0) += 1;
            }
        }

        // min_df below 1.0 is treated as a fraction of documents; 1.0 or more
        // as an absolute document count (the default of 1.0 keeps every term
        // that appears in at least one document).
        let min_df_count = if self.min_df < 1.0 {
            (self.min_df * n_docs as f64).ceil() as usize
        } else {
            self.min_df as usize
        };

        // max_df up to 1.0 is treated as a fraction of documents; larger
        // values as an absolute document count.
        let max_df_count = if self.max_df <= 1.0 {
            (self.max_df * n_docs as f64).floor() as usize
        } else {
            self.max_df as usize
        };

        let mut filtered_terms: Vec<(String, usize)> = term_doc_freq
            .into_iter()
            .filter(|(_, freq)| *freq >= min_df_count && *freq <= max_df_count)
            .collect();

        // Sort by descending document frequency, breaking ties alphabetically
        // so the learned vocabulary is deterministic.
        filtered_terms.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        if let Some(max_feat) = self.max_features {
            filtered_terms.truncate(max_feat);
        }

        self.vocabulary.clear();
        self.feature_names.clear();

        for (idx, (term, _freq)) in filtered_terms.into_iter().enumerate() {
            self.vocabulary.insert(term.clone(), idx);
            self.feature_names.push(term);
        }

        self.fitted = true;
        Ok(())
    }

    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        if !self.fitted {
            return Err(TransformError::NotFitted(
                "CountVectorizer must be fitted before transform".into(),
            ));
        }

        let n_samples = documents.len();
        let n_features = self.vocabulary.len();
        let mut result = Array2::zeros((n_samples, n_features));

        for (i, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);
            for token in tokens {
                if let Some(&idx) = self.vocabulary.get(&token) {
                    result[[i, idx]] += 1.0;
                }
            }
        }

        Ok(result)
    }

    pub fn fit_transform(&mut self, documents: &[String]) -> Result<Array2<f64>> {
        self.fit(documents)?;
        self.transform(documents)
    }

    #[allow(dead_code)]
    pub fn get_feature_names(&self) -> &[String] {
        &self.feature_names
    }
}

impl Default for CountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}
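
// Added usage sketch (not from the original file): a minimal test exercising the
// CountVectorizer defaults. The module name and sample documents are illustrative.
#[cfg(test)]
mod count_vectorizer_usage {
    use super::*;

    #[test]
    fn builds_a_term_count_matrix() {
        let docs = vec!["The cat sat".to_string(), "the dog sat".to_string()];

        let mut vectorizer = CountVectorizer::new();
        let counts = vectorizer.fit_transform(&docs).unwrap();

        // After lowercasing, the vocabulary is {"the", "sat", "cat", "dog"}.
        assert_eq!(counts.shape(), &[2, 4]);
        assert_eq!(vectorizer.get_feature_names().len(), 4);
        // Each document contributes three token occurrences.
        assert!((counts.row(0).sum() - 3.0).abs() < 1e-12);
    }
}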

/// TF-IDF vectorizer built on top of `CountVectorizer`: raw counts are
/// reweighted by inverse document frequency and optionally L2-normalized.
pub struct TfidfVectorizer {
    count_vectorizer: CountVectorizer,
    idf: Array1<f64>,
    use_idf: bool,
    norm: bool,
    smooth_idf: bool,
    sublinear_tf: bool,
}

impl TfidfVectorizer {
    pub fn new() -> Self {
        TfidfVectorizer {
            count_vectorizer: CountVectorizer::new(),
            idf: Array1::zeros(0),
            use_idf: true,
            norm: true,
            smooth_idf: true,
            sublinear_tf: false,
        }
    }

    #[allow(dead_code)]
    pub fn with_use_idf(mut self, use_idf: bool) -> Self {
        self.use_idf = use_idf;
        self
    }

    #[allow(dead_code)]
    pub fn with_norm(mut self, norm: bool) -> Self {
        self.norm = norm;
        self
    }

    #[allow(dead_code)]
    pub fn with_smooth_idf(mut self, smooth_idf: bool) -> Self {
        self.smooth_idf = smooth_idf;
        self
    }

    #[allow(dead_code)]
    pub fn with_sublinear_tf(mut self, sublinear_tf: bool) -> Self {
        self.sublinear_tf = sublinear_tf;
        self
    }

    #[allow(dead_code)]
    pub fn configure_count_vectorizer<F>(mut self, f: F) -> Self
    where
        F: FnOnce(CountVectorizer) -> CountVectorizer,
    {
        self.count_vectorizer = f(self.count_vectorizer);
        self
    }

    pub fn fit(&mut self, documents: &[String]) -> Result<()> {
        self.count_vectorizer.fit(documents)?;

        if self.use_idf {
            let n_samples = documents.len() as f64;
            let n_features = self.count_vectorizer.vocabulary.len();
            let mut df = Array1::zeros(n_features);

            // Document frequency: the number of documents containing each term.
            for doc in documents {
                let tokens: HashSet<String> =
                    self.count_vectorizer.tokenize(doc).into_iter().collect();
                for token in tokens {
                    if let Some(&idx) = self.count_vectorizer.vocabulary.get(&token) {
                        df[idx] += 1.0;
                    }
                }
            }

            // Smoothed IDF: ln((n + 1) / (df + 1)) + 1; unsmoothed: ln(n / df) + 1.
            if self.smooth_idf {
                self.idf = df.mapv(|d: f64| ((n_samples + 1.0) / (d + 1.0)).ln() + 1.0);
            } else {
                self.idf = df.mapv(|d: f64| (n_samples / d).ln() + 1.0);
            }
        }

        Ok(())
    }

    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        let mut x = self.count_vectorizer.transform(documents)?;

        if self.sublinear_tf {
            x.mapv_inplace(|v| if v > 0.0 { 1.0 + v.ln() } else { 0.0 });
        }

        if self.use_idf {
            for i in 0..x.shape()[0] {
                for j in 0..x.shape()[1] {
                    x[[i, j]] *= self.idf[j];
                }
            }
        }

        if self.norm {
            for i in 0..x.shape()[0] {
                let row = x.row(i);
                let norm = row.dot(&row).sqrt();
                if norm > 0.0 {
                    x.row_mut(i).mapv_inplace(|v| v / norm);
                }
            }
        }

        Ok(x)
    }

    pub fn fit_transform(&mut self, documents: &[String]) -> Result<Array2<f64>> {
        self.fit(documents)?;
        self.transform(documents)
    }

    #[allow(dead_code)]
    pub fn get_feature_names(&self) -> &[String] {
        self.count_vectorizer.get_feature_names()
    }
}

impl Default for TfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}
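
// Added usage sketch (not from the original file): demonstrates the default TF-IDF
// pipeline together with the `configure_count_vectorizer` hook. Sample documents
// and the max_features value are illustrative.
#[cfg(test)]
mod tfidf_vectorizer_usage {
    use super::*;

    #[test]
    fn produces_l2_normalized_rows() {
        let docs = vec![
            "the cat sat on the mat".to_string(),
            "the dog chased the cat".to_string(),
        ];

        let mut vectorizer =
            TfidfVectorizer::new().configure_count_vectorizer(|cv| cv.with_max_features(10));
        let x = vectorizer.fit_transform(&docs).unwrap();

        assert_eq!(x.shape()[0], 2);
        // With `norm` enabled (the default), every non-empty row has unit L2 norm.
        for i in 0..x.shape()[0] {
            let row = x.row(i);
            assert!((row.dot(&row).sqrt() - 1.0).abs() < 1e-9);
        }
    }
}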

/// Stateless vectorizer that maps tokens to a fixed number of columns with a
/// hash function (the "hashing trick"); distinct tokens may collide into the
/// same column.
pub struct HashingVectorizer {
    n_features: usize,
    lowercase: bool,
    token_pattern: Regex,
    binary: bool,
    norm: Option<String>,
}

impl HashingVectorizer {
    pub fn new(n_features: usize) -> Self {
        HashingVectorizer {
            n_features,
            lowercase: true,
            token_pattern: Regex::new(r"\b\w+\b").unwrap(),
            binary: false,
            norm: Some("l2".to_string()),
        }
    }

    #[allow(dead_code)]
    pub fn with_binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    #[allow(dead_code)]
    pub fn with_norm(mut self, norm: Option<String>) -> Self {
        self.norm = norm;
        self
    }

    #[allow(dead_code)]
    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    fn hash_token(&self, token: &str) -> usize {
        let mut hasher = AHasher::default();
        token.hash(&mut hasher);
        (hasher.finish() as usize) % self.n_features
    }

    fn tokenize(&self, doc: &str) -> Vec<String> {
        let text = if self.lowercase {
            doc.to_lowercase()
        } else {
            doc.to_string()
        };

        self.token_pattern
            .find_iter(&text)
            .map(|m| m.as_str().to_string())
            .collect()
    }

    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        let n_samples = documents.len();
        let mut result = Array2::zeros((n_samples, self.n_features));

        for (i, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);

            if self.binary {
                let unique_indices: HashSet<usize> =
                    tokens.iter().map(|token| self.hash_token(token)).collect();

                for idx in unique_indices {
                    result[[i, idx]] = 1.0;
                }
            } else {
                for token in tokens {
                    let idx = self.hash_token(&token);
                    result[[i, idx]] += 1.0;
                }
            }

            if let Some(ref norm_type) = self.norm {
                let row = result.row(i).to_owned();
                let norm_value = match norm_type.as_str() {
                    "l1" => row.iter().map(|v: &f64| v.abs()).sum::<f64>(),
                    "l2" => row.dot(&row).sqrt(),
                    _ => continue,
                };

                if norm_value > 0.0 {
                    result.row_mut(i).mapv_inplace(|v| v / norm_value);
                }
            }
        }

        Ok(result)
    }
}
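
// Added usage sketch (not from the original file): shows the stateless hashing
// vectorizer. The bucket count and sample documents are illustrative; unrelated
// tokens may collide into the same column.
#[cfg(test)]
mod hashing_vectorizer_usage {
    use super::*;

    #[test]
    fn maps_documents_into_a_fixed_width_matrix() {
        let docs = vec![
            "the quick brown fox".to_string(),
            "jumps over the lazy dog".to_string(),
        ];

        // No fitting step: column indices come directly from hashing the tokens.
        let vectorizer = HashingVectorizer::new(64);
        let x = vectorizer.transform(&docs).unwrap();

        assert_eq!(x.shape(), &[2, 64]);
        // With the default "l2" norm, non-empty rows have unit length.
        let row = x.row(0);
        assert!((row.dot(&row).sqrt() - 1.0).abs() < 1e-9);
    }
}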

/// Count vectorizer that learns its vocabulary incrementally from batches of
/// documents; when `max_features` is set, the least document-frequent entry
/// is evicted to make room for more frequent new tokens.
pub struct StreamingCountVectorizer {
    vocabulary: HashMap<String, usize>,
    doc_freq: HashMap<String, usize>,
    n_docs_seen: usize,
    max_features: Option<usize>,
    lowercase: bool,
    token_pattern: Regex,
}

impl StreamingCountVectorizer {
    pub fn new() -> Self {
        StreamingCountVectorizer {
            vocabulary: HashMap::new(),
            doc_freq: HashMap::new(),
            n_docs_seen: 0,
            max_features: None,
            lowercase: true,
            token_pattern: Regex::new(r"\b\w+\b").unwrap(),
        }
    }

    #[allow(dead_code)]
    pub fn with_max_features(mut self, max_features: usize) -> Self {
        self.max_features = Some(max_features);
        self
    }

    fn tokenize(&self, doc: &str) -> Vec<String> {
        let text = if self.lowercase {
            doc.to_lowercase()
        } else {
            doc.to_string()
        };

        self.token_pattern
            .find_iter(&text)
            .map(|m| m.as_str().to_string())
            .collect()
    }

    pub fn partial_fit(&mut self, documents: &[String]) -> Result<()> {
        for doc in documents {
            self.n_docs_seen += 1;
            let tokens: HashSet<String> = self.tokenize(doc).into_iter().collect();

            for token in tokens {
                *self.doc_freq.entry(token.clone()).or_insert(0) += 1;

                if self.vocabulary.contains_key(&token) {
                    continue;
                }

                let vocab_is_full = self
                    .max_features
                    .map_or(false, |max_feat| self.vocabulary.len() >= max_feat);

                if vocab_is_full {
                    // Evict the vocabulary entry with the lowest document
                    // frequency if the new token is seen more often, reusing
                    // the evicted entry's column index.
                    if let Some((min_token, _)) = self
                        .vocabulary
                        .iter()
                        .min_by_key(|(t, _)| self.doc_freq.get(*t).unwrap_or(&0))
                    {
                        let min_token = min_token.clone();
                        let min_freq = *self.doc_freq.get(&min_token).unwrap_or(&0);
                        let new_freq = *self.doc_freq.get(&token).unwrap_or(&0);

                        if new_freq > min_freq {
                            let old_idx = self.vocabulary.remove(&min_token).unwrap();
                            self.vocabulary.insert(token, old_idx);
                        }
                    }
                } else {
                    let idx = self.vocabulary.len();
                    self.vocabulary.insert(token, idx);
                }
            }
        }

        Ok(())
    }

    pub fn transform(&self, documents: &[String]) -> Result<Array2<f64>> {
        let n_samples = documents.len();
        let n_features = self.vocabulary.len();

        if n_features == 0 {
            return Err(TransformError::NotFitted(
                "No vocabulary learned yet".into(),
            ));
        }

        let mut result = Array2::zeros((n_samples, n_features));

        for (i, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);
            for token in tokens {
                if let Some(&idx) = self.vocabulary.get(&token) {
                    result[[i, idx]] += 1.0;
                }
            }
        }

        Ok(result)
    }
}

impl Default for StreamingCountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}
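
// Added usage sketch (not from the original file): incremental vocabulary building
// over two batches followed by a transform. Batch contents are illustrative.
#[cfg(test)]
mod streaming_count_vectorizer_usage {
    use super::*;

    #[test]
    fn learns_vocabulary_incrementally() {
        let batch_one = vec!["rust is fast".to_string()];
        let batch_two = vec!["rust is safe".to_string()];

        let mut vectorizer = StreamingCountVectorizer::new();
        vectorizer.partial_fit(&batch_one).unwrap();
        vectorizer.partial_fit(&batch_two).unwrap();

        // Vocabulary after both batches: {"rust", "is", "fast", "safe"}.
        let x = vectorizer
            .transform(&[batch_one[0].clone(), batch_two[0].clone()])
            .unwrap();
        assert_eq!(x.shape(), &[2, 4]);
        assert!((x.row(0).sum() - 3.0).abs() < 1e-12);
    }
}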