// SPDX-License-Identifier: MIT OR Apache-2.0
//! TF-IDF text vectorizer.
//!
//! Term Frequency–Inverse Document Frequency weighting, built on top of
//! [`CountVectorizer`]. Analogous to scikit-learn's `TfidfVectorizer`.

use super::count::CountVectorizer;
use crate::sparse::CsrMatrix;

/// Normalization method for TF-IDF vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TfidfNorm {
    /// L1 normalization (sum of absolute values = 1).
    L1,
    /// L2 normalization (Euclidean length = 1). Default.
    #[default]
    L2,
    /// No normalization.
    None,
}

/// TF-IDF text vectorizer.
///
/// Combines count vectorization with IDF weighting and optional
/// normalization. Produces a sparse CSR matrix.
///
/// # Example
///
/// ```ignore
/// use scry_learn::text::TfidfVectorizer;
///
/// let docs = ["the cat sat", "the dog sat", "the cat played"];
/// let mut tfidf = TfidfVectorizer::new();
/// let matrix = tfidf.fit_transform(&docs);
/// ```
#[derive(Debug, Clone)]
pub struct TfidfVectorizer {
    /// Underlying count vectorizer (tokenization, vocabulary, raw counts).
    count: CountVectorizer,
    /// Learned IDF weights, one per vocabulary term; populated by `fit()`.
    idf_values: Vec<f64>,
    /// Per-row normalization applied after TF-IDF weighting.
    norm: TfidfNorm,
    /// If true, use `1 + ln(tf)` instead of raw `tf`.
    sublinear_tf: bool,
    /// If true, add 1 to document frequencies (and to the document count)
    /// to prevent zero division; matches scikit-learn's default.
    smooth_idf: bool,
    /// Whether fit() has been called; `transform()` asserts on this.
    fitted: bool,
}

51impl TfidfVectorizer {
52    /// Create a new `TfidfVectorizer` with default settings.
53    pub fn new() -> Self {
54        Self {
55            count: CountVectorizer::new(),
56            idf_values: Vec::new(),
57            norm: TfidfNorm::L2,
58            sublinear_tf: false,
59            smooth_idf: true,
60            fitted: false,
61        }
62    }
63
64    /// Set minimum document frequency.
65    pub fn min_df(mut self, n: usize) -> Self {
66        self.count = self.count.min_df(n);
67        self
68    }
69
70    /// Set maximum document frequency fraction.
71    pub fn max_df(mut self, frac: f64) -> Self {
72        self.count = self.count.max_df(frac);
73        self
74    }
75
76    /// Set n-gram range.
77    pub fn ngram_range(mut self, min_n: usize, max_n: usize) -> Self {
78        self.count = self.count.ngram_range(min_n, max_n);
79        self
80    }
81
82    /// Limit vocabulary size.
83    pub fn max_features(mut self, n: usize) -> Self {
84        self.count = self.count.max_features(n);
85        self
86    }
87
88    /// Set normalization method. Default: L2.
89    pub fn norm(mut self, norm: TfidfNorm) -> Self {
90        self.norm = norm;
91        self
92    }
93
94    /// Enable sublinear TF scaling: `tf = 1 + log(tf)`.
95    pub fn sublinear_tf(mut self, enable: bool) -> Self {
96        self.sublinear_tf = enable;
97        self
98    }
99
100    /// Enable smooth IDF: adds 1 to document frequencies.
101    /// Default: true (matches sklearn).
102    pub fn smooth_idf(mut self, enable: bool) -> Self {
103        self.smooth_idf = enable;
104        self
105    }
106
107    /// Learn vocabulary and IDF weights from documents.
108    pub fn fit<S: AsRef<str>>(&mut self, documents: &[S]) {
109        self.count.fit(documents);
110
111        let n_docs = documents.len();
112        let n_features = self.count.n_features();
113
114        // Compute document frequency for each term.
115        let mut doc_freq = vec![0usize; n_features];
116        let vocab = self.count.vocabulary();
117
118        for doc in documents {
119            let grams = self.count.tokenize_doc(doc.as_ref());
120            let mut seen = std::collections::HashSet::new();
121            for gram in &grams {
122                if let Some(&idx) = vocab.get(gram) {
123                    if seen.insert(idx) {
124                        doc_freq[idx] += 1;
125                    }
126                }
127            }
128        }
129
130        // Compute IDF.
131        self.idf_values = vec![0.0; n_features];
132        let smooth = if self.smooth_idf { 1.0 } else { 0.0 };
133        let n = n_docs as f64 + smooth;
134
135        for (i, &df) in doc_freq.iter().enumerate() {
136            let df_smooth = df as f64 + smooth;
137            self.idf_values[i] = (n / df_smooth).ln() + 1.0;
138        }
139
140        self.fitted = true;
141    }
142
143    /// Transform documents into a TF-IDF weighted sparse matrix.
144    pub fn transform<S: AsRef<str>>(&self, documents: &[S]) -> CsrMatrix {
145        assert!(
146            self.fitted,
147            "TfidfVectorizer: must call fit() before transform()"
148        );
149
150        let counts = self.count.transform(documents);
151        let n_rows = counts.n_rows();
152        let n_cols = counts.n_cols();
153
154        if n_rows == 0 || n_cols == 0 {
155            return CsrMatrix::from_dense(&[]);
156        }
157
158        // Get dense counts, apply TF-IDF weighting, then rebuild as sparse.
159        let count_dense = counts.to_dense();
160
161        let mut triplet_rows = Vec::new();
162        let mut triplet_cols = Vec::new();
163        let mut triplet_vals = Vec::new();
164
165        for (row_idx, row) in count_dense.iter().enumerate() {
166            let mut row_entries: Vec<(usize, f64)> = Vec::new();
167
168            for (col, &count) in row.iter().enumerate() {
169                if count == 0.0 {
170                    continue;
171                }
172
173                let tf = if self.sublinear_tf {
174                    1.0 + count.ln()
175                } else {
176                    count
177                };
178
179                let idf = self.idf_values.get(col).copied().unwrap_or(1.0);
180                let tfidf = tf * idf;
181                row_entries.push((col, tfidf));
182            }
183
184            // Normalize.
185            if !row_entries.is_empty() {
186                match self.norm {
187                    TfidfNorm::L2 => {
188                        let norm: f64 = row_entries.iter().map(|(_, v)| v * v).sum::<f64>().sqrt();
189                        if norm > 0.0 {
190                            for entry in &mut row_entries {
191                                entry.1 /= norm;
192                            }
193                        }
194                    }
195                    TfidfNorm::L1 => {
196                        let norm: f64 = row_entries.iter().map(|(_, v)| v.abs()).sum();
197                        if norm > 0.0 {
198                            for entry in &mut row_entries {
199                                entry.1 /= norm;
200                            }
201                        }
202                    }
203                    TfidfNorm::None => {}
204                }
205            }
206
207            for (col, val) in row_entries {
208                triplet_rows.push(row_idx);
209                triplet_cols.push(col);
210                triplet_vals.push(val);
211            }
212        }
213
214        CsrMatrix::from_triplets(&triplet_rows, &triplet_cols, &triplet_vals, n_rows, n_cols)
215            .expect("TfidfVectorizer: internal CSR construction error")
216    }
217
218    /// Fit and transform in one step.
219    pub fn fit_transform<S: AsRef<str>>(&mut self, documents: &[S]) -> CsrMatrix {
220        self.fit(documents);
221        self.transform(documents)
222    }
223
224    /// Return the learned IDF weights.
225    pub fn idf(&self) -> &[f64] {
226        &self.idf_values
227    }
228
229    /// Return the underlying vocabulary.
230    pub fn vocabulary(&self) -> &std::collections::HashMap<String, usize> {
231        self.count.vocabulary()
232    }
233
234    /// Return feature names sorted by column index.
235    pub fn get_feature_names(&self) -> Vec<String> {
236        self.count.get_feature_names()
237    }
238
239    /// Number of features.
240    pub fn n_features(&self) -> usize {
241        self.count.n_features()
242    }
243}
244
245impl Default for TfidfVectorizer {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_fit_transform() {
        let docs = ["the cat sat", "the dog sat", "the cat played"];
        let mut vectorizer = TfidfVectorizer::new();
        let tfidf = vectorizer.fit_transform(&docs);

        assert_eq!(tfidf.n_rows(), 3);
        assert_eq!(tfidf.n_cols(), vectorizer.n_features());
        // Vocabulary: the, cat, dog, sat, played.
        assert_eq!(vectorizer.n_features(), 5);
    }

    #[test]
    fn idf_values_are_positive() {
        let docs = ["hello world", "hello test"];
        let mut vectorizer = TfidfVectorizer::new();
        vectorizer.fit(&docs);

        for &idf in vectorizer.idf() {
            assert!(idf > 0.0, "IDF should be positive, got {idf}");
        }
    }

    #[test]
    fn l2_normalization() {
        let docs = ["a b c", "a b b"];
        let mut vectorizer = TfidfVectorizer::new().norm(TfidfNorm::L2);
        let rows = vectorizer.fit_transform(&docs).to_dense();

        for row in &rows {
            let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
            if norm > 0.0 {
                assert!(
                    (norm - 1.0).abs() < 1e-10,
                    "L2 norm should be 1.0, got {norm}"
                );
            }
        }
    }

    #[test]
    fn l1_normalization() {
        let docs = ["a b c"];
        let mut vectorizer = TfidfVectorizer::new().norm(TfidfNorm::L1);
        let rows = vectorizer.fit_transform(&docs).to_dense();

        let norm: f64 = rows[0].iter().map(|v| v.abs()).sum();
        assert!(
            (norm - 1.0).abs() < 1e-10,
            "L1 norm should be 1.0, got {norm}"
        );
    }

    #[test]
    fn no_normalization() {
        // "a" appears twice in the single document: tf = 2, and with
        // smooth_idf the weight is ln(2/2) + 1 = 1.0, so tf-idf = 2.0.
        let docs = ["a a"];
        let mut vectorizer = TfidfVectorizer::new().norm(TfidfNorm::None);
        let rows = vectorizer.fit_transform(&docs).to_dense();

        assert!(
            rows[0].iter().any(|&v| v > 1.0),
            "Expected unnormalized values"
        );
    }

    #[test]
    fn smooth_idf_default() {
        let docs = ["a", "b"];
        let mut vectorizer = TfidfVectorizer::new();
        vectorizer.fit(&docs);

        // Smooth IDF: ln((n + 1) / (df + 1)) + 1; for "a" that is
        // ln(3/2) + 1 ≈ 1.405, so every weight exceeds 1.0.
        for &idf in vectorizer.idf() {
            assert!(idf > 1.0, "Smooth IDF should be > 1.0, got {idf}");
        }
    }

    #[test]
    fn sublinear_tf() {
        let docs = ["a a a a a"];
        let mut vectorizer = TfidfVectorizer::new()
            .sublinear_tf(true)
            .norm(TfidfNorm::None);
        let rows = vectorizer.fit_transform(&docs).to_dense();

        // Raw tf would be 5; sublinear gives 1 + ln(5) ≈ 2.609, and the
        // smoothed IDF here is ln(2/2) + 1 = 1.0.
        let &value = rows[0].iter().find(|&&v| v > 0.0).unwrap();
        assert!(value < 5.0, "Sublinear TF should reduce high counts");
    }

    #[test]
    fn unseen_terms_ignored() {
        let train = ["cat dog"];
        let test = ["cat bird"]; // "bird" is out of vocabulary

        let mut vectorizer = TfidfVectorizer::new();
        vectorizer.fit(&train);

        let rows = vectorizer.transform(&test).to_dense();
        let nonzero = rows[0].iter().filter(|&&v| v > 0.0).count();
        assert_eq!(nonzero, 1, "Only 'cat' should have a non-zero value");
    }

    #[test]
    fn bigrams_tfidf() {
        let docs = ["the cat sat"];
        let mut vectorizer = TfidfVectorizer::new().ngram_range(2, 2);
        // Bigrams: "the cat" and "cat sat".
        assert_eq!(vectorizer.fit_transform(&docs).n_cols(), 2);
    }

    #[test]
    fn empty_documents() {
        let docs: [&str; 0] = [];
        let mut vectorizer = TfidfVectorizer::new();
        assert_eq!(vectorizer.fit_transform(&docs).n_rows(), 0);
    }
}