1use crate::sparse::CsrMatrix;
8use std::collections::HashMap;
9
/// Bag-of-words / n-gram count vectorizer: learns a token→column vocabulary
/// from a corpus (`fit`) and maps documents to sparse count matrices
/// (`transform`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CountVectorizer {
    // Token/n-gram -> column index; built by `fit`, alphabetical order.
    vocabulary: HashMap<String, usize>,
    // Minimum document frequency (absolute count) a term needs to be kept.
    min_df: usize,
    // Maximum document frequency as a fraction of the corpus, in [0.0, 1.0].
    max_df: f64,
    // Inclusive (min_n, max_n) n-gram sizes passed to the tokenizer.
    ngram_range: (usize, usize),
    // Optional cap on vocabulary size; most frequent terms are kept.
    max_features: Option<usize>,
    // When true, transform emits 0/1 presence values instead of raw counts.
    binary: bool,
    // Set by `fit`; `transform` asserts on it.
    fitted: bool,
}
45
46impl CountVectorizer {
47 pub fn new() -> Self {
49 Self {
50 vocabulary: HashMap::new(),
51 min_df: 1,
52 max_df: 1.0,
53 ngram_range: (1, 1),
54 max_features: None,
55 binary: false,
56 fitted: false,
57 }
58 }
59
60 pub fn min_df(mut self, n: usize) -> Self {
63 self.min_df = n.max(1);
64 self
65 }
66
67 pub fn max_df(mut self, frac: f64) -> Self {
71 self.max_df = frac.clamp(0.0, 1.0);
72 self
73 }
74
75 pub fn ngram_range(mut self, min_n: usize, max_n: usize) -> Self {
77 self.ngram_range = (min_n.max(1), max_n.max(min_n.max(1)));
78 self
79 }
80
81 pub fn max_features(mut self, n: usize) -> Self {
84 self.max_features = Some(n);
85 self
86 }
87
88 pub fn binary(mut self, b: bool) -> Self {
91 self.binary = b;
92 self
93 }
94
95 pub fn fit<S: AsRef<str>>(&mut self, documents: &[S]) {
97 let n_docs = documents.len();
98
99 let mut doc_freq: HashMap<String, usize> = HashMap::new();
101 let mut total_freq: HashMap<String, usize> = HashMap::new();
102
103 for doc in documents {
104 let tokens = super::tokenizer::default_tokenize(doc.as_ref());
105 let grams = super::tokenizer::ngrams(&tokens, self.ngram_range);
106
107 let mut seen = std::collections::HashSet::new();
109 for gram in &grams {
110 if seen.insert(gram.clone()) {
111 *doc_freq.entry(gram.clone()).or_insert(0) += 1;
112 }
113 *total_freq.entry(gram.clone()).or_insert(0) += 1;
114 }
115 }
116
117 let max_df_abs = (self.max_df * n_docs as f64).ceil() as usize;
119 let mut candidates: Vec<(String, usize)> = total_freq
120 .into_iter()
121 .filter(|(token, _)| {
122 let df = doc_freq.get(token).copied().unwrap_or(0);
123 df >= self.min_df && df <= max_df_abs
124 })
125 .collect();
126
127 candidates.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
129
130 if let Some(max_f) = self.max_features {
132 candidates.truncate(max_f);
133 }
134
135 candidates.sort_by(|a, b| a.0.cmp(&b.0));
138 self.vocabulary.clear();
139 for (idx, (token, _)) in candidates.into_iter().enumerate() {
140 self.vocabulary.insert(token, idx);
141 }
142
143 self.fitted = true;
144 }
145
146 pub fn transform<S: AsRef<str>>(&self, documents: &[S]) -> CsrMatrix {
150 assert!(
151 self.fitted,
152 "CountVectorizer: must call fit() before transform()"
153 );
154
155 let n_rows = documents.len();
156 let n_cols = self.vocabulary.len();
157
158 if n_rows == 0 || n_cols == 0 {
159 return CsrMatrix::from_dense(&[]);
160 }
161
162 let mut triplet_rows = Vec::new();
163 let mut triplet_cols = Vec::new();
164 let mut triplet_vals = Vec::new();
165
166 for (row_idx, doc) in documents.iter().enumerate() {
167 let tokens = super::tokenizer::default_tokenize(doc.as_ref());
168 let grams = super::tokenizer::ngrams(&tokens, self.ngram_range);
169
170 let mut counts: HashMap<usize, f64> = HashMap::new();
172 for gram in &grams {
173 if let Some(&col) = self.vocabulary.get(gram) {
174 *counts.entry(col).or_insert(0.0) += 1.0;
175 }
176 }
177
178 for (col, val) in counts {
179 let v = if self.binary { 1.0 } else { val };
180 triplet_rows.push(row_idx);
181 triplet_cols.push(col);
182 triplet_vals.push(v);
183 }
184 }
185
186 CsrMatrix::from_triplets(&triplet_rows, &triplet_cols, &triplet_vals, n_rows, n_cols)
187 .expect("CountVectorizer: internal CSR construction error")
188 }
189
190 pub fn fit_transform<S: AsRef<str>>(&mut self, documents: &[S]) -> CsrMatrix {
192 self.fit(documents);
193 self.transform(documents)
194 }
195
196 pub fn vocabulary(&self) -> &HashMap<String, usize> {
198 &self.vocabulary
199 }
200
201 pub fn get_feature_names(&self) -> Vec<String> {
203 let mut pairs: Vec<(&String, &usize)> = self.vocabulary.iter().collect();
204 pairs.sort_by_key(|&(_, &idx)| idx);
205 pairs.into_iter().map(|(name, _)| name.clone()).collect()
206 }
207
208 pub fn n_features(&self) -> usize {
210 self.vocabulary.len()
211 }
212
213 pub fn is_fitted(&self) -> bool {
215 self.fitted
216 }
217
218 pub(crate) fn tokenize_doc(&self, text: &str) -> Vec<String> {
220 let tokens = super::tokenizer::default_tokenize(text);
221 super::tokenizer::ngrams(&tokens, self.ngram_range)
222 }
223}
224
impl Default for CountVectorizer {
    /// Equivalent to [`CountVectorizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
230
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn fit_transform_basic() {
        let corpus = ["the cat sat", "the dog sat", "the cat played"];
        let mut vectorizer = CountVectorizer::new();
        let mat = vectorizer.fit_transform(&corpus);

        assert_eq!(mat.n_rows(), 3);
        assert_eq!(mat.n_cols(), vectorizer.vocabulary().len());
        for word in ["the", "cat", "dog", "sat", "played"] {
            assert!(vectorizer.vocabulary().contains_key(word));
        }
        assert_eq!(vectorizer.n_features(), 5);
    }

    #[test]
    fn vocabulary_order() {
        let corpus = ["b c a", "a b"];
        let mut vectorizer = CountVectorizer::new();
        vectorizer.fit(&corpus);

        // Columns are assigned alphabetically.
        assert_eq!(vectorizer.get_feature_names(), vec!["a", "b", "c"]);
    }

    #[test]
    fn counts_are_correct() {
        let corpus = ["a a b"];
        let mut vectorizer = CountVectorizer::new();
        let dense = vectorizer.fit_transform(&corpus).to_dense();

        let col_a = vectorizer.vocabulary()["a"];
        let col_b = vectorizer.vocabulary()["b"];
        assert_eq!(dense[0][col_a], 2.0);
        assert_eq!(dense[0][col_b], 1.0);
    }

    #[test]
    fn binary_mode() {
        let corpus = ["a a a b"];
        let mut vectorizer = CountVectorizer::new().binary(true);
        let dense = vectorizer.fit_transform(&corpus).to_dense();

        // Repeated terms collapse to a 0/1 indicator.
        assert_eq!(dense[0][vectorizer.vocabulary()["a"]], 1.0);
    }

    #[test]
    fn min_df_filters() {
        let corpus = ["a b c", "a b", "a"];
        let mut vectorizer = CountVectorizer::new().min_df(2);
        vectorizer.fit(&corpus);

        assert!(vectorizer.vocabulary().contains_key("a"));
        assert!(vectorizer.vocabulary().contains_key("b"));
        assert!(!vectorizer.vocabulary().contains_key("c"));
    }

    #[test]
    fn max_df_filters() {
        let corpus = ["a b", "a c", "a d"];
        let mut vectorizer = CountVectorizer::new().max_df(0.5);
        vectorizer.fit(&corpus);

        // "a" appears in every document and exceeds the max-df cutoff.
        assert!(!vectorizer.vocabulary().contains_key("a"));
        assert!(vectorizer.vocabulary().contains_key("b"));
    }

    #[test]
    fn max_features_limits() {
        let corpus = ["a a a b b c"];
        let mut vectorizer = CountVectorizer::new().max_features(2);
        vectorizer.fit(&corpus);

        assert_eq!(vectorizer.n_features(), 2);
    }

    #[test]
    fn bigrams() {
        let corpus = ["the cat sat"];
        let mut vectorizer = CountVectorizer::new().ngram_range(2, 2);
        let mat = vectorizer.fit_transform(&corpus);

        assert!(vectorizer.vocabulary().contains_key("the cat"));
        assert!(vectorizer.vocabulary().contains_key("cat sat"));
        assert_eq!(mat.n_cols(), 2);
    }

    #[test]
    fn unigrams_and_bigrams() {
        let corpus = ["the cat sat"];
        let mut vectorizer = CountVectorizer::new().ngram_range(1, 2);
        vectorizer.fit(&corpus);

        // 3 unigrams + 2 bigrams.
        assert_eq!(vectorizer.n_features(), 5);
    }

    #[test]
    fn transform_unseen_terms() {
        let train = ["the cat sat"];
        let unseen = ["the bird flew"];

        let mut vectorizer = CountVectorizer::new();
        vectorizer.fit(&train);

        let dense = vectorizer.transform(&unseen).to_dense();
        assert_eq!(dense[0][vectorizer.vocabulary()["the"]], 1.0);

        // Only "the" is known; everything else is dropped.
        let total: f64 = dense[0].iter().sum();
        assert_eq!(total, 1.0);
    }

    #[test]
    fn empty_documents() {
        let corpus: [&str; 0] = [];
        let mut vectorizer = CountVectorizer::new();
        let mat = vectorizer.fit_transform(&corpus);

        assert_eq!(mat.n_rows(), 0);
        assert_eq!(mat.n_cols(), 0);
    }

    #[test]
    fn string_refs_accepted() {
        let corpus: Vec<String> = vec!["hello world".into(), "hello test".into()];
        let mut vectorizer = CountVectorizer::new();
        assert_eq!(vectorizer.fit_transform(&corpus).n_rows(), 2);
    }
}
377}