
scry_learn/text/count.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Count-based text vectorizer.
//!
//! Converts a collection of text documents into a sparse matrix of token
//! counts, analogous to scikit-learn's `CountVectorizer`.

use crate::sparse::CsrMatrix;
use std::collections::{HashMap, HashSet};

/// Converts text documents into a sparse term-count matrix.
///
/// Each document becomes a row, each unique token a column. Cell values
/// are the number of times that token appears in that document.
///
/// # Example
///
/// ```ignore
/// use scry_learn::text::CountVectorizer;
///
/// let mut cv = CountVectorizer::new();
/// let docs = ["the cat sat", "the dog sat", "the cat played"];
/// let matrix = cv.fit_transform(&docs);
///
/// assert_eq!(matrix.n_rows(), 3);
/// assert_eq!(matrix.n_cols(), cv.vocabulary().len());
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CountVectorizer {
    /// Token → column index mapping.
    vocabulary: HashMap<String, usize>,
    /// Minimum document frequency (absolute count).
    min_df: usize,
    /// Maximum document frequency as a fraction of total documents.
    max_df: f64,
    /// N-gram range `(min_n, max_n)`.
    ngram_range: (usize, usize),
    /// Maximum number of features (vocabulary size).
    max_features: Option<usize>,
    /// If true, all non-zero counts are set to 1.
    binary: bool,
    /// Whether `fit()` has been called.
    fitted: bool,
}

impl CountVectorizer {
    /// Create a new `CountVectorizer` with default settings.
    pub fn new() -> Self {
        Self {
            vocabulary: HashMap::new(),
            min_df: 1,
            max_df: 1.0,
            ngram_range: (1, 1),
            max_features: None,
            binary: false,
            fitted: false,
        }
    }

    /// Set minimum document frequency (absolute). Tokens appearing in
    /// fewer documents are excluded. Default: 1.
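    ///
    /// A minimal sketch, mirroring the `min_df_filters` test below:
    ///
    /// ```ignore
    /// // Keep only tokens that appear in at least 2 documents.
    /// let mut cv = CountVectorizer::new().min_df(2);
    /// cv.fit(&["a b c", "a b", "a"]);
    /// assert!(!cv.vocabulary().contains_key("c")); // "c" is in only 1 doc
    /// ```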
    pub fn min_df(mut self, n: usize) -> Self {
        self.min_df = n.max(1);
        self
    }

    /// Set maximum document frequency as a fraction; values are clamped
    /// to `[0.0, 1.0]`. Tokens appearing in more than this fraction of
    /// documents are excluded. Default: 1.0 (no filtering).
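    ///
    /// A minimal sketch, mirroring the `max_df_filters` test below:
    ///
    /// ```ignore
    /// // Drop tokens that appear in more than half of the documents.
    /// let mut cv = CountVectorizer::new().max_df(0.5);
    /// cv.fit(&["a b", "a c", "a d"]);
    /// assert!(!cv.vocabulary().contains_key("a")); // "a" is in every doc
    /// ```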
    pub fn max_df(mut self, frac: f64) -> Self {
        self.max_df = frac.clamp(0.0, 1.0);
        self
    }

    /// Set n-gram range. Default: `(1, 1)` (unigrams only).
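    ///
    /// A minimal sketch, mirroring the `bigrams` test below:
    ///
    /// ```ignore
    /// // Extract bigrams only: "the cat" and "cat sat".
    /// let mut cv = CountVectorizer::new().ngram_range(2, 2);
    /// cv.fit(&["the cat sat"]);
    /// assert_eq!(cv.n_features(), 2);
    /// ```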
    pub fn ngram_range(mut self, min_n: usize, max_n: usize) -> Self {
        // Clamp so that 1 <= min_n <= max_n always holds.
        let lo = min_n.max(1);
        self.ngram_range = (lo, max_n.max(lo));
        self
    }

    /// Limit vocabulary to the top `n` features by total frequency.
    /// Default: no limit.
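    ///
    /// A minimal sketch, mirroring the `max_features_limits` test below:
    ///
    /// ```ignore
    /// // Keep only the 2 most frequent tokens ("a" and "b").
    /// let mut cv = CountVectorizer::new().max_features(2);
    /// cv.fit(&["a a a b b c"]);
    /// assert_eq!(cv.n_features(), 2);
    /// ```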
    pub fn max_features(mut self, n: usize) -> Self {
        self.max_features = Some(n);
        self
    }

    /// If true, all non-zero counts become 1 (presence/absence).
    /// Default: false.
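    ///
    /// A minimal sketch, mirroring the `binary_mode` test below:
    ///
    /// ```ignore
    /// let mut cv = CountVectorizer::new().binary(true);
    /// let matrix = cv.fit_transform(&["a a a b"]);
    /// let a_idx = cv.vocabulary()["a"];
    /// assert_eq!(matrix.to_dense()[0][a_idx], 1.0); // not 3.0
    /// ```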
    pub fn binary(mut self, b: bool) -> Self {
        self.binary = b;
        self
    }

    /// Learn vocabulary from documents.
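    ///
    /// Tokens are filtered by `min_df`/`max_df`, capped at `max_features`
    /// (most frequent first), and the survivors are assigned column
    /// indices in alphabetical order. A minimal sketch:
    ///
    /// ```ignore
    /// let mut cv = CountVectorizer::new();
    /// cv.fit(&["b c a", "a b"]);
    /// assert_eq!(cv.get_feature_names(), vec!["a", "b", "c"]);
    /// ```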
    pub fn fit<S: AsRef<str>>(&mut self, documents: &[S]) {
        let n_docs = documents.len();

        // Count document frequency for each token.
        let mut doc_freq: HashMap<String, usize> = HashMap::new();
        let mut total_freq: HashMap<String, usize> = HashMap::new();

        for doc in documents {
            let tokens = super::tokenizer::default_tokenize(doc.as_ref());
            let grams = super::tokenizer::ngrams(&tokens, self.ngram_range);

            // Track unique tokens per document for doc frequency.
            let mut seen = HashSet::new();
            for gram in &grams {
                if seen.insert(gram.clone()) {
                    *doc_freq.entry(gram.clone()).or_insert(0) += 1;
                }
                *total_freq.entry(gram.clone()).or_insert(0) += 1;
            }
        }

        // Apply min_df / max_df filters.
        let max_df_abs = (self.max_df * n_docs as f64).ceil() as usize;
        let mut candidates: Vec<(String, usize)> = total_freq
            .into_iter()
            .filter(|(token, _)| {
                let df = doc_freq.get(token).copied().unwrap_or(0);
                df >= self.min_df && df <= max_df_abs
            })
            .collect();

        // Sort by frequency descending, then alphabetically for stability.
        candidates.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        // Apply max_features cap.
        if let Some(max_f) = self.max_features {
            candidates.truncate(max_f);
        }

        // Build vocabulary map.
        // Sort alphabetically for stable column ordering.
        candidates.sort_by(|a, b| a.0.cmp(&b.0));
        self.vocabulary.clear();
        for (idx, (token, _)) in candidates.into_iter().enumerate() {
            self.vocabulary.insert(token, idx);
        }

        self.fitted = true;
    }

    /// Transform documents into a sparse CSR matrix of counts.
    ///
    /// # Panics
    ///
    /// Panics if `fit()` has not been called.
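    ///
    /// Tokens absent from the learned vocabulary are silently ignored, as
    /// in this minimal sketch mirroring the `transform_unseen_terms` test:
    ///
    /// ```ignore
    /// let mut cv = CountVectorizer::new();
    /// cv.fit(&["the cat sat"]);
    /// let matrix = cv.transform(&["the bird flew"]); // only "the" counted
    /// assert_eq!(matrix.n_rows(), 1);
    /// ```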
    pub fn transform<S: AsRef<str>>(&self, documents: &[S]) -> CsrMatrix {
        assert!(
            self.fitted,
            "CountVectorizer: must call fit() before transform()"
        );

        let n_rows = documents.len();
        let n_cols = self.vocabulary.len();

        if n_rows == 0 || n_cols == 0 {
            // Build a zero-sized matrix that still preserves the row and
            // column counts (from_dense(&[]) would collapse both to 0).
            return CsrMatrix::from_triplets(&[], &[], &[], n_rows, n_cols)
                .expect("CountVectorizer: internal CSR construction error");
        }

        let mut triplet_rows = Vec::new();
        let mut triplet_cols = Vec::new();
        let mut triplet_vals = Vec::new();

        for (row_idx, doc) in documents.iter().enumerate() {
            let tokens = super::tokenizer::default_tokenize(doc.as_ref());
            let grams = super::tokenizer::ngrams(&tokens, self.ngram_range);

            // Count occurrences.
            let mut counts: HashMap<usize, f64> = HashMap::new();
            for gram in &grams {
                if let Some(&col) = self.vocabulary.get(gram) {
                    *counts.entry(col).or_insert(0.0) += 1.0;
                }
            }

            for (col, val) in counts {
                let v = if self.binary { 1.0 } else { val };
                triplet_rows.push(row_idx);
                triplet_cols.push(col);
                triplet_vals.push(v);
            }
        }

        CsrMatrix::from_triplets(&triplet_rows, &triplet_cols, &triplet_vals, n_rows, n_cols)
            .expect("CountVectorizer: internal CSR construction error")
    }

    /// Fit the vocabulary and transform in one step.
    pub fn fit_transform<S: AsRef<str>>(&mut self, documents: &[S]) -> CsrMatrix {
        self.fit(documents);
        self.transform(documents)
    }

    /// Return the learned vocabulary (token → column index).
    pub fn vocabulary(&self) -> &HashMap<String, usize> {
        &self.vocabulary
    }

    /// Return feature names sorted by column index.
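    ///
    /// A minimal sketch of the ordering guarantee:
    ///
    /// ```ignore
    /// let mut cv = CountVectorizer::new();
    /// cv.fit(&["hello world"]);
    /// assert_eq!(cv.get_feature_names(), vec!["hello", "world"]);
    /// ```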
    pub fn get_feature_names(&self) -> Vec<String> {
        let mut pairs: Vec<(&String, &usize)> = self.vocabulary.iter().collect();
        pairs.sort_by_key(|&(_, &idx)| idx);
        pairs.into_iter().map(|(name, _)| name.clone()).collect()
    }

    /// Number of features in the vocabulary.
    pub fn n_features(&self) -> usize {
        self.vocabulary.len()
    }

    /// Whether `fit()` has been called.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }

    /// Tokenize and generate n-grams for a single document (internal helper).
    pub(crate) fn tokenize_doc(&self, text: &str) -> Vec<String> {
        let tokens = super::tokenizer::default_tokenize(text);
        super::tokenizer::ngrams(&tokens, self.ngram_range)
    }
}

impl Default for CountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn fit_transform_basic() {
        let docs = ["the cat sat", "the dog sat", "the cat played"];
        let mut cv = CountVectorizer::new();
        let matrix = cv.fit_transform(&docs);

        assert_eq!(matrix.n_rows(), 3);
        assert_eq!(matrix.n_cols(), cv.vocabulary().len());
        // All five tokens survive the default (no-filtering) settings.
        assert!(cv.vocabulary().contains_key("the"));
        assert!(cv.vocabulary().contains_key("cat"));
        assert!(cv.vocabulary().contains_key("dog"));
        assert!(cv.vocabulary().contains_key("sat"));
        assert!(cv.vocabulary().contains_key("played"));
        assert_eq!(cv.n_features(), 5); // the, cat, dog, sat, played
    }

    #[test]
    fn vocabulary_order() {
        let docs = ["b c a", "a b"];
        let mut cv = CountVectorizer::new();
        cv.fit(&docs);

        let names = cv.get_feature_names();
        assert_eq!(names, vec!["a", "b", "c"]); // alphabetical
    }

    #[test]
    fn counts_are_correct() {
        let docs = ["a a b"];
        let mut cv = CountVectorizer::new();
        let matrix = cv.fit_transform(&docs);
        let dense = matrix.to_dense();

        let a_idx = cv.vocabulary()["a"];
        let b_idx = cv.vocabulary()["b"];
        assert_eq!(dense[0][a_idx], 2.0);
        assert_eq!(dense[0][b_idx], 1.0);
    }

    #[test]
    fn binary_mode() {
        let docs = ["a a a b"];
        let mut cv = CountVectorizer::new().binary(true);
        let matrix = cv.fit_transform(&docs);
        let dense = matrix.to_dense();

        let a_idx = cv.vocabulary()["a"];
        assert_eq!(dense[0][a_idx], 1.0); // binary: max 1
    }

    #[test]
    fn min_df_filters() {
        let docs = ["a b c", "a b", "a"];
        let mut cv = CountVectorizer::new().min_df(2);
        cv.fit(&docs);

        assert!(cv.vocabulary().contains_key("a"));
        assert!(cv.vocabulary().contains_key("b"));
        assert!(!cv.vocabulary().contains_key("c")); // only in 1 doc
    }

    #[test]
    fn max_df_filters() {
        let docs = ["a b", "a c", "a d"];
        let mut cv = CountVectorizer::new().max_df(0.5);
        cv.fit(&docs);

        // "a" is in 100% of docs (> 50%), should be filtered
        assert!(!cv.vocabulary().contains_key("a"));
        assert!(cv.vocabulary().contains_key("b"));
    }

    #[test]
    fn max_features_limits() {
        let docs = ["a a a b b c"];
        let mut cv = CountVectorizer::new().max_features(2);
        cv.fit(&docs);

        assert_eq!(cv.n_features(), 2);
    }

    #[test]
    fn bigrams() {
        let docs = ["the cat sat"];
        let mut cv = CountVectorizer::new().ngram_range(2, 2);
        let matrix = cv.fit_transform(&docs);

        assert!(cv.vocabulary().contains_key("the cat"));
        assert!(cv.vocabulary().contains_key("cat sat"));
        assert_eq!(matrix.n_cols(), 2);
    }

    #[test]
    fn unigrams_and_bigrams() {
        let docs = ["the cat sat"];
        let mut cv = CountVectorizer::new().ngram_range(1, 2);
        cv.fit(&docs);

        // 3 unigrams + 2 bigrams = 5 features
        assert_eq!(cv.n_features(), 5);
    }

    #[test]
    fn transform_unseen_terms() {
        let train = ["the cat sat"];
        let test = ["the bird flew"];

        let mut cv = CountVectorizer::new();
        cv.fit(&train);

        let matrix = cv.transform(&test);
        let dense = matrix.to_dense();

        // "bird" and "flew" are not in vocabulary, should be ignored
        let the_idx = cv.vocabulary()["the"];
        assert_eq!(dense[0][the_idx], 1.0);

        // The row sums to 1.0: only "the" matched, and it appears once.
        let row_sum: f64 = dense[0].iter().sum();
        assert_eq!(row_sum, 1.0);
    }

    #[test]
    fn empty_documents() {
        let docs: [&str; 0] = [];
        let mut cv = CountVectorizer::new();
        let matrix = cv.fit_transform(&docs);

        assert_eq!(matrix.n_rows(), 0);
        assert_eq!(matrix.n_cols(), 0);
    }

    #[test]
    fn string_refs_accepted() {
        // Verify it works with Vec<String> too, not just &[&str].
        let docs: Vec<String> = vec!["hello world".into(), "hello test".into()];
        let mut cv = CountVectorizer::new();
        let matrix = cv.fit_transform(&docs);
        assert_eq!(matrix.n_rows(), 2);
    }
}
377}