use crate::{
    data::dataset::{Dataset, RealNumber, WholeNumber},
    metrics::confusion::ClassificationMetrics,
};
use nalgebra::{DMatrix, DVector};
use std::{
    collections::{HashMap, HashSet},
    error::Error,
};

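/// Gaussian Naive Bayes classifier.
///
/// `fit` estimates, for every class in the training data, the class prior
/// frequency together with the per-feature mean and variance. `predict`
/// assigns each sample to the class with the highest Gaussian log-likelihood
/// plus log-prior.
///
/// A minimal usage sketch (assuming `Dataset` and `GaussianNB` are in scope):
///
/// ```ignore
/// use nalgebra::{DMatrix, DVector};
///
/// let x = DMatrix::from_row_slice(4, 2, &[1.0, 2.0, 1.5, 2.5, 8.0, 9.0, 8.5, 9.5]);
/// let y = DVector::from_column_slice(&[0, 0, 1, 1]);
/// let dataset = Dataset::new(x, y);
///
/// let mut clf = GaussianNB::<f64, i32>::new();
/// clf.fit(&dataset).unwrap();
/// let y_pred = clf.predict(&DMatrix::from_row_slice(1, 2, &[1.2, 2.2])).unwrap();
/// assert_eq!(y_pred[0], 0);
/// ```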
#[derive(Clone, Debug)]
pub struct GaussianNB<XT: RealNumber, YT: WholeNumber> {
    class_freq: HashMap<YT, XT>,
    class_mean: HashMap<YT, DVector<XT>>,
    class_variance: HashMap<YT, DVector<XT>>,
}

impl<XT: RealNumber, YT: WholeNumber> ClassificationMetrics<YT> for GaussianNB<XT, YT> {}

impl<XT: RealNumber, YT: WholeNumber> Default for GaussianNB<XT, YT> {
    fn default() -> Self {
        Self::new()
    }
}

impl<XT: RealNumber, YT: WholeNumber> GaussianNB<XT, YT> {
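    /// Creates an unfitted classifier with empty class statistics.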
    pub fn new() -> Self {
        Self {
            class_freq: HashMap::new(),
            class_mean: HashMap::new(),
            class_variance: HashMap::new(),
        }
    }

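    /// Returns the per-class prior frequencies computed by `fit`.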
    pub fn class_freq(&self) -> &HashMap<YT, XT> {
        &self.class_freq
    }

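    /// Returns the per-class feature means computed by `fit`.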
    pub fn class_mean(&self) -> &HashMap<YT, DVector<XT>> {
        &self.class_mean
    }

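    /// Returns the per-class feature variances computed by `fit`.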
    pub fn class_variance(&self) -> &HashMap<YT, DVector<XT>> {
        &self.class_variance
    }

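    /// Fits the model: for every class present in `dataset`, stores the class
    /// prior frequency and the per-feature mean and variance.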
    pub fn fit(&mut self, dataset: &Dataset<XT, YT>) -> Result<String, Box<dyn Error>> {
        let (x, y) = dataset.into_parts();
        let classes = y.iter().cloned().collect::<HashSet<_>>();

        for class in classes {
            // Select the rows of `x` that belong to the current class.
            let class_mask = y.map(|label| label == class);
            let class_indices = class_mask
                .iter()
                .enumerate()
                .filter(|&(_, &value)| value)
                .map(|(index, _)| index)
                .collect::<Vec<_>>();
            let x_class = x.select_rows(class_indices.as_slice());

            // Per-feature mean and variance for this class.
            let mean = DVector::from_fn(x_class.ncols(), |col, _| {
                self.mean(&x_class.column(col).into_owned())
            });
            let variance = DVector::from_fn(x_class.ncols(), |col, _| {
                self.variance(&x_class.column(col).into_owned())
            });

            // Class prior: fraction of training rows labelled with this class.
            let freq =
                XT::from_usize(class_indices.len()).unwrap() / XT::from_usize(x.nrows()).unwrap();

            self.class_freq.insert(class, freq);
            self.class_mean.insert(class, mean);
            self.class_variance.insert(class, variance);
        }
        Ok("Finished fitting".into())
    }

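    /// Arithmetic mean of a single feature column.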
    fn mean(&self, x: &DVector<XT>) -> XT {
        let zero = XT::from_f64(0.0).unwrap();
        let sum: XT = x.fold(zero, |acc, x| acc + x);

        sum / XT::from_usize(x.len()).unwrap()
    }

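    /// Unbiased sample variance of a single feature column (`n - 1` denominator).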
    fn variance(&self, x: &DVector<XT>) -> XT {
        let mean = self.mean(x);
        let zero = XT::from_f64(0.0).unwrap();
        let numerator = x.fold(zero, |acc, x| acc + (x - mean) * (x - mean));

        numerator / XT::from_usize(x.len() - 1).unwrap()
    }

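    /// Scores one sample against every fitted class and returns the class with
    /// the highest Gaussian log-likelihood plus log-prior.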
    fn predict_single(&self, x: &DVector<XT>) -> Result<YT, Box<dyn Error>> {
        let mut max_log_likelihood = XT::from_f64(f64::NEG_INFINITY).unwrap();
        let mut max_class = YT::from_i8(0).unwrap();

        for class in self.class_freq.keys() {
            let mean = self
                .class_mean
                .get(class)
                .ok_or(format!("Mean for class {:?} wasn't calculated.", class))?;
            let variance = self
                .class_variance
                .get(class)
                .ok_or(format!("Variance for class {:?} wasn't calculated.", class))?;
            // Small constant added to every variance so constant features do not
            // cause a division by zero or `ln(0)`.
            let variance_epsilon =
                DVector::<XT>::from_element(variance.len(), XT::from_f64(1e-9).unwrap());

            // Gaussian log-likelihood of the sample under this class (dropping the
            // constant -0.5 * ln(2 * pi) term), plus the log-prior:
            // -0.5 * sum((x - mean)^2 / var) - 0.5 * sum(ln(var)) + ln(prior)
            let neg_half = XT::from_f64(-0.5).unwrap();
            let log_likelihood = neg_half
                * ((x - mean)
                    .component_mul(&(x - mean))
                    .component_div(&(variance + &variance_epsilon)))
                .sum()
                + neg_half * (variance + &variance_epsilon).map(|v| v.ln()).sum()
                + self
                    .class_freq
                    .get(class)
                    .ok_or(format!("Frequency of class {:?} wasn't obtained.", class))?
                    .ln();

            if log_likelihood > max_log_likelihood {
                max_log_likelihood = log_likelihood;
                max_class = *class;
            }
        }
        Ok(max_class)
    }

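    /// Predicts a class label for each row of `x`.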
    pub fn predict(&self, x: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
        let mut y_pred = Vec::new();

        for i in 0..x.nrows() {
            let x_row = x.row(i).into_owned().transpose();
            let class = self.predict_single(&x_row)?;
            y_pred.push(class);
        }

        Ok(DVector::from_vec(y_pred))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_new() {
        let clf = GaussianNB::<f64, i32>::new();

        assert!(clf.class_freq.is_empty());
        assert!(clf.class_mean.is_empty());
        assert!(clf.class_variance.is_empty());
    }

    #[test]
    fn test_model_fit() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(
            4,
            3,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        );
        let y = DVector::from_column_slice(&[0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        assert_abs_diff_eq!(*clf.class_freq.get(&0).unwrap(), 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(*clf.class_freq.get(&1).unwrap(), 0.5, epsilon = 1e-7);
    }

    #[test]
    fn test_predictions() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(
            4,
            3,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        );
        let y = DVector::from_column_slice(&[0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        let test_x = DMatrix::from_row_slice(2, 3, &[2.0, 3.0, 4.0, 6.0, 7.0, 8.0]);

        let pred_y = clf.predict(&test_x).unwrap();

        assert_eq!(pred_y, DVector::from_column_slice(&[0, 1]));
    }

    #[test]
    fn test_empty_data() {
        let mut clf = GaussianNB::<f64, i32>::new();
        let empty_x = DMatrix::<f64>::zeros(0, 0);
        let empty_y = DVector::<i32>::zeros(0);
        let empty_pred_y = clf.predict(&empty_x).unwrap();
        assert_eq!(empty_pred_y.len(), 0);
        let dataset = Dataset::new(empty_x, empty_y);

        let _ = clf.fit(&dataset);
        assert_eq!(clf.class_freq.len(), 0);
        assert_eq!(clf.class_mean.len(), 0);
        assert_eq!(clf.class_variance.len(), 0);
    }

    #[test]
    fn test_single_class() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 2.0, 3.0, 3.0, 4.0]);
        let y = DVector::from_column_slice(&[0, 0, 0]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        assert_eq!(clf.class_freq.len(), 1);
        assert_eq!(clf.class_mean.len(), 1);
        assert_eq!(clf.class_variance.len(), 1);

        let test_x = DMatrix::from_row_slice(2, 2, &[1.5, 2.5, 2.5, 3.5]);

        let pred_y = clf.predict(&test_x).unwrap();

        assert_eq!(pred_y, DVector::from_column_slice(&[0, 0]));
    }

    #[test]
    fn test_predict_with_constant_feature() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(4, 2, &[0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
        let y = DVector::from_vec(vec![0, 0, 1, 1]);

        let x_new = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 1.0]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        let y_hat = clf.predict(&x_new).unwrap();

        assert_eq!(y_hat.len(), 2);
        assert_eq!(y_hat[0], 0);
        assert_eq!(y_hat[1], 1);
    }

    #[test]
    fn test_gaussian_nb() {
        let mut clf = GaussianNB::<f64, i32>::new();

        let x = DMatrix::from_row_slice(
            4,
            3,
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        );
        let y = DVector::from_column_slice(&[0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        let _ = clf.fit(&dataset);

        assert_abs_diff_eq!(*clf.class_freq.get(&0).unwrap(), 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(*clf.class_freq.get(&1).unwrap(), 0.5, epsilon = 1e-7);

        let test_x = DMatrix::from_row_slice(2, 3, &[2.0, 3.0, 4.0, 6.0, 7.0, 8.0]);

        let pred_y = clf.predict(&test_x).unwrap();

        assert_eq!(pred_y, DVector::from_column_slice(&[0, 1]));
    }
}