1use crate::dataset::Dataset;
28use crate::error::{Result, ScryLearnError};
29use crate::sparse::{CscMatrix, CsrMatrix};
30use crate::weights::{compute_sample_weights, ClassWeight};
31
/// Multinomial Naive Bayes classifier for discrete count features
/// (e.g. word counts), fitted with additive (Laplace/Lidstone) smoothing.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MultinomialNB {
    // Additive smoothing parameter; defaults to 1.0 in `new()`.
    alpha: f64,
    // Per-class sample weighting strategy applied during `fit`.
    class_weight: ClassWeight,
    // Smoothed log P(feature j | class c), indexed [class][feature]; set by fit.
    log_probs: Vec<Vec<f64>>,
    // Log class priors, one per class; set by fit.
    log_priors: Vec<f64>,
    // Number of classes inferred during the last fit.
    n_classes: usize,
    // True once fit/fit_sparse has completed successfully.
    fitted: bool,
    // Serialization schema version; defaulted on deserialize so older
    // payloads missing the field still load (then fail the version check).
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
55
56impl MultinomialNB {
57 pub fn new() -> Self {
59 Self {
60 alpha: 1.0,
61 class_weight: ClassWeight::Uniform,
62 log_probs: Vec::new(),
63 log_priors: Vec::new(),
64 n_classes: 0,
65 fitted: false,
66 _schema_version: crate::version::SCHEMA_VERSION,
67 }
68 }
69
70 pub fn alpha(mut self, a: f64) -> Self {
72 self.alpha = a;
73 self
74 }
75
76 pub fn class_weight(mut self, cw: ClassWeight) -> Self {
78 self.class_weight = cw;
79 self
80 }
81
82 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
86 data.validate_finite()?;
87 let n = data.n_samples();
88 let m = data.n_features();
89 if n == 0 {
90 return Err(ScryLearnError::EmptyDataset);
91 }
92
93 self.n_classes = data.n_classes();
94 let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
95
96 let mut feature_sum = vec![vec![0.0_f64; m]; self.n_classes];
98 let mut class_weight_sum = vec![0.0_f64; self.n_classes];
99
100 for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
101 let c = target_val as usize;
102 if c >= self.n_classes {
103 continue;
104 }
105 class_weight_sum[c] += sw;
106 for (j, feat_col) in data.features.iter().enumerate() {
107 feature_sum[c][j] += sw * feat_col[i];
108 }
109 }
110
111 self.log_probs = vec![vec![0.0; m]; self.n_classes];
114 for (c_probs, c_sums) in self.log_probs.iter_mut().zip(feature_sum.iter()) {
115 let total: f64 = c_sums.iter().sum::<f64>() + self.alpha * m as f64;
116 for (lp, &fs) in c_probs.iter_mut().zip(c_sums.iter()) {
117 *lp = ((fs + self.alpha) / total).ln();
118 }
119 }
120
121 let total_weight: f64 = class_weight_sum.iter().sum();
123 self.log_priors = class_weight_sum
124 .iter()
125 .map(|&w| (w / total_weight).ln())
126 .collect();
127
128 self.fitted = true;
129 Ok(())
130 }
131
132 #[allow(clippy::needless_range_loop)]
136 pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
137 let n = features.n_rows();
138 let m = features.n_cols();
139 if n == 0 {
140 return Err(ScryLearnError::EmptyDataset);
141 }
142 if target.len() != n {
143 return Err(ScryLearnError::InvalidParameter(format!(
144 "target length {} != n_rows {}",
145 target.len(),
146 n
147 )));
148 }
149
150 let max_class = target.iter().map(|&t| t as usize).max().unwrap_or(0);
151 self.n_classes = max_class + 1;
152 let sample_weights = compute_sample_weights(target, &self.class_weight);
153
154 let mut feature_sum = vec![vec![0.0_f64; m]; self.n_classes];
155 let mut class_weight_sum = vec![0.0_f64; self.n_classes];
156
157 for (&sw, &t) in sample_weights.iter().zip(target.iter()) {
158 let c = t as usize;
159 if c < self.n_classes {
160 class_weight_sum[c] += sw;
161 }
162 }
163
164 for j in 0..m {
166 for (row_idx, val) in features.col(j).iter() {
167 let c = target[row_idx] as usize;
168 if c < self.n_classes {
169 feature_sum[c][j] += sample_weights[row_idx] * val;
170 }
171 }
172 }
173
174 self.log_probs = vec![vec![0.0; m]; self.n_classes];
176 for (c_probs, c_sums) in self.log_probs.iter_mut().zip(feature_sum.iter()) {
177 let total: f64 = c_sums.iter().sum::<f64>() + self.alpha * m as f64;
178 for (lp, &fs) in c_probs.iter_mut().zip(c_sums.iter()) {
179 *lp = ((fs + self.alpha) / total).ln();
180 }
181 }
182
183 let total_weight: f64 = class_weight_sum.iter().sum();
184 self.log_priors = class_weight_sum
185 .iter()
186 .map(|&w| (w / total_weight).ln())
187 .collect();
188
189 self.fitted = true;
190 Ok(())
191 }
192
193 pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
195 if !self.fitted {
196 return Err(ScryLearnError::NotFitted);
197 }
198 let probas = self.predict_proba_sparse(features)?;
199 Ok(probas
200 .iter()
201 .map(|probs| {
202 probs
203 .iter()
204 .enumerate()
205 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
206 .map_or(0.0, |(idx, _)| idx as f64)
207 })
208 .collect())
209 }
210
211 pub fn predict_proba_sparse(&self, features: &CsrMatrix) -> Result<Vec<Vec<f64>>> {
215 if !self.fitted {
216 return Err(ScryLearnError::NotFitted);
217 }
218 Ok((0..features.n_rows())
219 .map(|i| {
220 let row = features.row(i);
221 let mut log_probs: Vec<f64> = (0..self.n_classes)
222 .map(|c| {
223 let mut lp = self.log_priors[c];
224 for (col, val) in row.iter() {
226 if col < self.log_probs[c].len() {
227 lp += val * self.log_probs[c][col];
228 }
229 }
230 lp
231 })
232 .collect();
233
234 let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
235 let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
236 for lp in &mut log_probs {
237 *lp = ((*lp - max_log).exp()) / sum;
238 }
239 log_probs
240 })
241 .collect())
242 }
243
244 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
246 crate::version::check_schema_version(self._schema_version)?;
247 if !self.fitted {
248 return Err(ScryLearnError::NotFitted);
249 }
250 let probas = self.predict_proba(features)?;
251 Ok(probas
252 .iter()
253 .map(|probs| {
254 probs
255 .iter()
256 .enumerate()
257 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
258 .map_or(0.0, |(idx, _)| idx as f64)
259 })
260 .collect())
261 }
262
263 pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
265 if !self.fitted {
266 return Err(ScryLearnError::NotFitted);
267 }
268
269 Ok(features
270 .iter()
271 .map(|row| {
272 let mut log_probs: Vec<f64> = (0..self.n_classes)
273 .map(|c| {
274 let mut lp = self.log_priors[c];
275 for (j, &x) in row.iter().enumerate() {
276 if j >= self.log_probs[c].len() {
277 continue;
278 }
279 lp += x * self.log_probs[c][j];
281 }
282 lp
283 })
284 .collect();
285
286 let max_log = log_probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
288 let sum: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
289 for lp in &mut log_probs {
290 *lp = ((*lp - max_log).exp()) / sum;
291 }
292 log_probs
293 })
294 .collect())
295 }
296}
297
298impl Default for MultinomialNB {
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // A model trained on count features should pick the class with the
    // dominant word counts.
    #[test]
    fn test_multinomial_nb_counts() {
        // Column-major features: two feature columns, six samples each.
        let cols = vec![
            vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
        ];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let ds = Dataset::new(
            cols,
            labels,
            vec!["word_a".into(), "word_b".into()],
            "class",
        );

        let mut model = MultinomialNB::new();
        model.fit(&ds).unwrap();

        let out = model.predict(&[vec![4.0, 0.0], vec![0.0, 5.0]]).unwrap();
        assert!((out[0] - 0.0).abs() < 1e-6, "high word_a → class 0");
        assert!((out[1] - 1.0).abs() < 1e-6, "high word_b → class 1");
    }

    // Probabilities from predict_proba must be a proper distribution.
    #[test]
    fn test_multinomial_nb_predict_proba() {
        let cols = vec![vec![5.0, 5.0, 0.0, 0.0], vec![0.0, 0.0, 5.0, 5.0]];
        let labels = vec![0.0, 0.0, 1.0, 1.0];
        let ds = Dataset::new(cols, labels, vec!["f0".into(), "f1".into()], "class");

        let mut model = MultinomialNB::new();
        model.fit(&ds).unwrap();

        let probas = model.predict_proba(&[vec![4.0, 0.0]]).unwrap();
        assert_eq!(probas[0].len(), 2);
        let sum: f64 = probas[0].iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "probabilities must sum to 1.0, got {sum}"
        );
    }

    // Fitting via the sparse path must agree with the dense path.
    #[test]
    fn test_sparse_multinomial_nb_matches_dense() {
        let cols = vec![
            vec![5.0, 6.0, 4.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 5.0, 6.0, 4.0],
        ];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let ds = Dataset::new(
            cols.clone(),
            labels.clone(),
            vec!["w_a".into(), "w_b".into()],
            "class",
        );

        let mut dense_model = MultinomialNB::new();
        dense_model.fit(&ds).unwrap();

        let csc = CscMatrix::from_dense(&cols);
        let mut sparse_model = MultinomialNB::new();
        sparse_model.fit_sparse(&csc, &labels).unwrap();

        let queries = vec![vec![4.0, 0.0], vec![0.0, 5.0]];
        let dense_preds = dense_model.predict(&queries).unwrap();
        let csr = CsrMatrix::from_dense(&queries);
        let sparse_preds = sparse_model.predict_sparse(&csr).unwrap();

        for (d, s) in dense_preds.iter().zip(sparse_preds.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }
}