
scry_learn/naive_bayes/multinomial.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Multinomial Naive Bayes classifier for count/frequency features.
//!
//! Suited for text classification with bag-of-words count vectors,
//! TF-IDF features, or other non-negative feature data.
//!
//! # Example
//!
//! ```
//! use scry_learn::naive_bayes::MultinomialNB;
//! use scry_learn::dataset::Dataset;
//!
//! // Column-major features: one vec per feature (word_a, word_b counts)
//! // across four samples.
//! let data = Dataset::new(
//!     vec![vec![5.0, 5.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 5.0]],
//!     vec![0.0, 0.0, 1.0, 1.0],
//!     vec!["word_a".into(), "word_b".into()],
//!     "category",
//! );
//!
//! let mut nb = MultinomialNB::new();
//! nb.fit(&data).unwrap();
//! let preds = nb.predict(&[vec![4.0, 0.0]]).unwrap();
//! assert!((preds[0] - 0.0).abs() < 1e-6);
//! ```

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::sparse::{CscMatrix, CsrMatrix};
use crate::weights::{compute_sample_weights, ClassWeight};

/// Multinomial Naive Bayes for count/frequency features.
///
/// Models each class as a multinomial distribution over features.
/// Well-suited for document classification with term frequencies.
///
/// Uses Laplace smoothing (additive smoothing) to handle zero counts.
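///
/// # Example
///
/// Builder-style configuration (a minimal sketch; `class_weight` chains the
/// same way with a `ClassWeight` variant):
///
/// ```
/// use scry_learn::naive_bayes::MultinomialNB;
///
/// // Weaker smoothing than the default alpha = 1.0.
/// let nb = MultinomialNB::new().alpha(0.5);
/// ```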
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MultinomialNB {
    /// Laplace smoothing parameter.
    alpha: f64,
    /// Class weighting strategy.
    class_weight: ClassWeight,
    /// Log-probabilities of features given class: `log_probs[class][feature]`.
    log_probs: Vec<Vec<f64>>,
    /// Log prior probabilities per class.
    log_priors: Vec<f64>,
    n_classes: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl MultinomialNB {
    /// Create a new Multinomial Naive Bayes classifier.
    pub fn new() -> Self {
        Self {
            alpha: 1.0,
            class_weight: ClassWeight::Uniform,
            log_probs: Vec::new(),
            log_priors: Vec::new(),
            n_classes: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set the Laplace smoothing parameter (default 1.0).
    pub fn alpha(mut self, a: f64) -> Self {
        self.alpha = a;
        self
    }

    /// Set class weighting strategy for imbalanced datasets.
    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
        self.class_weight = cw;
        self
    }

    /// Train the model on a dataset.
    ///
    /// Features should be non-negative counts or frequencies.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        self.n_classes = data.n_classes();
        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);

        // Compute weighted feature sums per class.
        let mut feature_sum = vec![vec![0.0_f64; m]; self.n_classes];
        let mut class_weight_sum = vec![0.0_f64; self.n_classes];

        for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
            let c = target_val as usize;
            if c >= self.n_classes {
                continue;
            }
            class_weight_sum[c] += sw;
            for (j, feat_col) in data.features.iter().enumerate() {
                feature_sum[c][j] += sw * feat_col[i];
            }
        }

        // Compute smoothed log-probabilities:
        // theta_cj = (N_cj + alpha) / (N_c + n_features * alpha),
        // where N_cj is the weighted count of feature j in class c and N_c = sum_j N_cj.
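        // Worked sketch (illustrative numbers, not from the tests): with
        // alpha = 1, m = 2, and weighted class sums [9.0, 1.0], the total is
        // 10 + 2 = 12, so theta = [10/12, 2/12] ≈ [0.833, 0.167], which sums
        // to 1 as a proper distribution must.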
        self.log_probs = vec![vec![0.0; m]; self.n_classes];
        for (c_probs, c_sums) in self.log_probs.iter_mut().zip(feature_sum.iter()) {
            let total: f64 = c_sums.iter().sum::<f64>() + self.alpha * m as f64;
            for (lp, &fs) in c_probs.iter_mut().zip(c_sums.iter()) {
                *lp = ((fs + self.alpha) / total).ln();
            }
        }

        // Log priors.
        let total_weight: f64 = class_weight_sum.iter().sum();
        self.log_priors = class_weight_sum
            .iter()
            .map(|&w| (w / total_weight).ln())
            .collect();

        self.fitted = true;
        Ok(())
    }

    /// Fit on sparse features (CSC format); well suited to TF-IDF matrices.
    ///
    /// Accumulates the weighted non-zero entries per class and feature to
    /// form the count-based likelihood.
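    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the sparse test below; it assumes
    /// `CscMatrix` is publicly exported at `scry_learn::sparse` and that
    /// `from_dense` takes one `vec` per feature column, as in that test:
    ///
    /// ```
    /// use scry_learn::naive_bayes::MultinomialNB;
    /// use scry_learn::sparse::CscMatrix;
    ///
    /// let features = vec![
    ///     vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
    ///     vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
    /// ];
    /// let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    ///
    /// let mut nb = MultinomialNB::new();
    /// nb.fit_sparse(&CscMatrix::from_dense(&features), &target).unwrap();
    /// ```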
    #[allow(clippy::needless_range_loop)]
    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
        let n = features.n_rows();
        let m = features.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if target.len() != n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "target length {} != n_rows {}",
                target.len(),
                n
            )));
        }

        let max_class = target.iter().map(|&t| t as usize).max().unwrap_or(0);
        self.n_classes = max_class + 1;
        let sample_weights = compute_sample_weights(target, &self.class_weight);

        let mut feature_sum = vec![vec![0.0_f64; m]; self.n_classes];
        let mut class_weight_sum = vec![0.0_f64; self.n_classes];

        for (&sw, &t) in sample_weights.iter().zip(target.iter()) {
            let c = t as usize;
            if c < self.n_classes {
                class_weight_sum[c] += sw;
            }
        }

        // Sum features per class using sparse column iteration.
        for j in 0..m {
            for (row_idx, val) in features.col(j).iter() {
                let c = target[row_idx] as usize;
                if c < self.n_classes {
                    feature_sum[c][j] += sample_weights[row_idx] * val;
                }
            }
        }

        // Smoothed log-probabilities.
        self.log_probs = vec![vec![0.0; m]; self.n_classes];
        for (c_probs, c_sums) in self.log_probs.iter_mut().zip(feature_sum.iter()) {
            let total: f64 = c_sums.iter().sum::<f64>() + self.alpha * m as f64;
            for (lp, &fs) in c_probs.iter_mut().zip(c_sums.iter()) {
                *lp = ((fs + self.alpha) / total).ln();
            }
        }

        let total_weight: f64 = class_weight_sum.iter().sum();
        self.log_priors = class_weight_sum
            .iter()
            .map(|&w| (w / total_weight).ln())
            .collect();

        self.fitted = true;
        Ok(())
    }

    /// Predict class labels from sparse features (CSR format).
    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
        // Mirror `predict`: reject models serialized under a different schema.
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let probas = self.predict_proba_sparse(features)?;
        Ok(probas
            .iter()
            .map(|probs| {
                probs
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect())
    }

    /// Predict normalized probabilities from sparse features (CSR format).
    ///
    /// Accumulates log-probability only over non-zero features, since a zero
    /// count contributes `x_j * log(theta_cj) = 0`.
    pub fn predict_proba_sparse(&self, features: &CsrMatrix) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok((0..features.n_rows())
            .map(|i| {
                let row = features.row(i);
                let mut log_probs: Vec<f64> = (0..self.n_classes)
                    .map(|c| {
                        let mut lp = self.log_priors[c];
                        // Only non-zero features contribute: x_j * log(theta_cj).
                        for (col, val) in row.iter() {
                            if col < self.log_probs[c].len() {
                                lp += val * self.log_probs[c][col];
                            }
                        }
                        lp
                    })
                    .collect();

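                // Log-sum-exp normalization, as in `predict_proba`:
                // p_c = exp(lp_c - max) / sum_k exp(lp_k - max).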
                let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
                for lp in &mut log_probs {
                    *lp = ((*lp - max_log).exp()) / sum;
                }
                log_probs
            })
            .collect())
    }

    /// Predict class labels.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let probas = self.predict_proba(features)?;
        Ok(probas
            .iter()
            .map(|probs| {
                probs
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect())
    }

    /// Predict normalized probabilities for each class.
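    ///
    /// # Example
    ///
    /// A minimal sketch reusing the module-level setup:
    ///
    /// ```
    /// use scry_learn::naive_bayes::MultinomialNB;
    /// use scry_learn::dataset::Dataset;
    ///
    /// let data = Dataset::new(
    ///     vec![vec![5.0, 5.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 5.0]],
    ///     vec![0.0, 0.0, 1.0, 1.0],
    ///     vec!["f0".into(), "f1".into()],
    ///     "class",
    /// );
    /// let mut nb = MultinomialNB::new();
    /// nb.fit(&data).unwrap();
    ///
    /// // Each output row is a distribution over classes and sums to 1.
    /// let probas = nb.predict_proba(&[vec![4.0, 0.0]]).unwrap();
    /// assert!((probas[0].iter().sum::<f64>() - 1.0).abs() < 1e-9);
    /// ```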
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }

        Ok(features
            .iter()
            .map(|row| {
                let mut log_probs: Vec<f64> = (0..self.n_classes)
                    .map(|c| {
                        let mut lp = self.log_priors[c];
                        for (j, &x) in row.iter().enumerate() {
                            if j >= self.log_probs[c].len() {
                                continue;
                            }
                            // Multinomial log-likelihood term: x_j * log(theta_cj).
                            lp += x * self.log_probs[c][j];
                        }
                        lp
                    })
                    .collect();

                // Log-sum-exp for numerical stability.
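                // Subtracting the per-row maximum before exponentiating keeps
                // exp() in range: p_c = exp(lp_c - max) / sum_k exp(lp_k - max).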
                let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
                for lp in &mut log_probs {
                    *lp = ((*lp - max_log).exp()) / sum;
                }
                log_probs
            })
            .collect())
    }
}

impl Default for MultinomialNB {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multinomial_nb_counts() {
        // Class 0: high word_a counts, Class 1: high word_b counts.
        let features = vec![
            vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features,
            target,
            vec!["word_a".into(), "word_b".into()],
            "class",
        );

        let mut nb = MultinomialNB::new();
        nb.fit(&data).unwrap();

        let preds = nb.predict(&[vec![4.0, 0.0], vec![0.0, 5.0]]).unwrap();
        assert!((preds[0] - 0.0).abs() < 1e-6, "high word_a → class 0");
        assert!((preds[1] - 1.0).abs() < 1e-6, "high word_b → class 1");
    }

    #[test]
    fn test_multinomial_nb_predict_proba() {
        let features = vec![vec![5.0, 5.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 5.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["f0".into(), "f1".into()], "class");

        let mut nb = MultinomialNB::new();
        nb.fit(&data).unwrap();

        let probas = nb.predict_proba(&[vec![4.0, 0.0]]).unwrap();
        assert_eq!(probas[0].len(), 2);
        let sum: f64 = probas[0].iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "probabilities must sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn test_sparse_multinomial_nb_matches_dense() {
        let features = vec![
            vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features.clone(),
            target.clone(),
            vec!["w_a".into(), "w_b".into()],
            "class",
        );

        let mut nb_dense = MultinomialNB::new();
        nb_dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&features);
        let mut nb_sparse = MultinomialNB::new();
        nb_sparse.fit_sparse(&csc, &target).unwrap();

        let test = vec![vec![4.0, 0.0], vec![0.0, 5.0]];
        let preds_dense = nb_dense.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = nb_sparse.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }
379}