1use crate::{
2 data::dataset::{Dataset, WholeNumber},
3 metrics::confusion::ClassificationMetrics,
4};
5use nalgebra::{DMatrix, DVector};
6use std::{
7 collections::{HashMap, HashSet},
8 error::Error,
9};
10
11pub struct CategoricalNB<T: WholeNumber> {
40 feature_class_freq: HashMap<T, DVector<HashMap<T, f64>>>,
41 label_class_freq: HashMap<T, f64>,
42 unique_feature_values_count: Vec<usize>,
43}
44
45impl<T: WholeNumber> Default for CategoricalNB<T> {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl<T: WholeNumber> ClassificationMetrics<T> for CategoricalNB<T> {}
52
53impl<T: WholeNumber> CategoricalNB<T> {
54 pub fn new() -> Self {
63 Self {
64 feature_class_freq: HashMap::new(),
65 label_class_freq: HashMap::new(),
66 unique_feature_values_count: Vec::new(),
67 }
68 }
69
70 pub fn feature_class_freq(&self) -> &HashMap<T, DVector<HashMap<T, f64>>> {
79 &self.feature_class_freq
80 }
81
82 pub fn label_class_freq(&self) -> &HashMap<T, f64> {
91 &self.label_class_freq
92 }
93
94 pub fn fit(&mut self, dataset: &Dataset<T, T>) -> Result<String, Box<dyn Error>> {
109 let (x, y) = dataset.into_parts();
110 let y_classes = y.iter().cloned().collect::<HashSet<_>>();
111
112 let mut unique_feature_values_count_temp = vec![HashSet::new(); x.ncols()];
113
114 x.column_iter().enumerate().for_each(|(idx, feature)| {
115 feature.iter().for_each(|&val| {
116 unique_feature_values_count_temp[idx].insert(val);
117 })
118 });
119
120 self.unique_feature_values_count = unique_feature_values_count_temp
121 .iter()
122 .map(|set| set.len())
123 .collect::<Vec<_>>();
124
125 for y_class in y_classes {
126 let class_mask = y.map(|label| label == y_class);
127 let class_indices = class_mask
128 .iter()
129 .enumerate()
130 .filter(|&(_, &value)| value)
131 .map(|(index, _)| index)
132 .collect::<Vec<_>>();
133
134 let x_y_class = x.select_rows(class_indices.as_slice());
135
136 let mut all_features_freq = DVector::from_element(x.ncols(), HashMap::new());
137 for (idx, feature) in x_y_class.column_iter().enumerate() {
138 let feature_count = feature.iter().fold(HashMap::new(), |mut acc, &val| {
139 *acc.entry(val).or_insert(0) += 1;
140 acc
141 });
142 let total_count =
143 class_indices.len() as f64 + self.unique_feature_values_count[idx] as f64;
144 let feature_freq = feature_count
145 .into_iter()
146 .map(|(class, count)| (class, (count as f64 + 1.0 / total_count)))
147 .collect();
148 all_features_freq[idx] = feature_freq;
149 }
150
151 let label_class_freq = class_indices.len() as f64 / y.nrows() as f64;
152
153 self.label_class_freq.insert(y_class, label_class_freq);
154 self.feature_class_freq.insert(y_class, all_features_freq);
155 }
156
157 Ok("Finished fitting".into())
158 }
159
160 fn predict_single(&self, x: &DVector<T>) -> Result<T, Box<dyn Error>> {
161 let mut max_prob = f64::NEG_INFINITY;
162 let mut max_class = T::from_i8(0).unwrap();
163
164 for (y_class, label_freq) in &self.label_class_freq {
165 let mut prob = label_freq.ln();
166
167 for (idx, feature) in x.iter().enumerate() {
168 let feature_probs = &self
169 .feature_class_freq
170 .get(y_class)
171 .ok_or(format!("Class {:?} wasn't obtained.", y_class))?[idx];
172
173 let total_feature_count = self.label_class_freq.values().sum::<f64>()
174 + self.unique_feature_values_count[idx] as f64;
175 let feature_prob = feature_probs
176 .get(feature)
177 .unwrap_or(&(1.0 / total_feature_count))
178 .ln();
179
180 prob += feature_prob;
181 }
182
183 if prob > max_prob {
184 max_prob = prob;
185 max_class = *y_class;
186 }
187 }
188
189 Ok(max_class)
190 }
191
192 pub fn predict(&self, x: &DMatrix<T>) -> Result<DVector<T>, Box<dyn Error>> {
208 let mut y_pred = Vec::new();
209
210 for i in 0..x.nrows() {
211 let x_row = x.row(i).transpose();
212 let y_class = self.predict_single(&x_row)?;
213 y_pred.push(y_class);
214 }
215 Ok(DVector::from_vec(y_pred))
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    /// Four samples, two binary features, two balanced classes.
    fn binary_features_and_labels() -> (DMatrix<i32>, DVector<i32>) {
        let features = DMatrix::from_row_slice(4, 2, &[1, 0, 1, 1, 0, 0, 0, 1]);
        let labels = DVector::from_vec(vec![0, 0, 1, 1]);
        (features, labels)
    }

    /// Three samples with three features, every value distinct.
    fn distinct_features() -> DMatrix<i32> {
        DMatrix::from_row_slice(3, 3, &[1, 2, 3, 4, 5, 6, 7, 8, 9])
    }

    #[test]
    fn test_new() {
        // A fresh model must not carry any learned tables.
        let classifier = CategoricalNB::<i32>::new();

        assert!(classifier.feature_class_freq.is_empty());
        assert!(classifier.label_class_freq.is_empty());
    }

    #[test]
    fn test_fit() {
        let mut classifier = CategoricalNB::<i32>::new();
        let labels = DVector::from_vec(vec![1, 2, 3]);
        let dataset = Dataset::new(distinct_features(), labels);

        let outcome = classifier.fit(&dataset);

        // One entry per distinct label in both tables.
        assert!(outcome.is_ok());
        assert_eq!(classifier.label_class_freq.len(), 3);
        assert_eq!(classifier.feature_class_freq.len(), 3);
    }

    #[test]
    fn test_predict_single() {
        let mut classifier = CategoricalNB::<i32>::new();
        let (features, labels) = binary_features_and_labels();
        let dataset = Dataset::new(features.clone(), labels);
        classifier.fit(&dataset).unwrap();

        // The first training row belongs to class 0.
        let sample = features.row(0).transpose();
        let predicted = classifier.predict_single(&sample).unwrap();

        assert_eq!(predicted, 0);
    }

    #[test]
    fn test_predict_with_unseen_feature_value() {
        let mut classifier = CategoricalNB::<i32>::new();
        let (features, labels) = binary_features_and_labels();
        let dataset = Dataset::new(features, labels);
        classifier.fit(&dataset).unwrap();

        // The value 2 never occurs in either training column.
        let sample = DVector::from_vec(vec![2, 2]);
        let predicted = classifier.predict_single(&sample).unwrap();

        assert!(predicted == 0 || predicted == 1);
    }

    #[test]
    fn test_predict() {
        let mut classifier = CategoricalNB::<i32>::new();
        let features = distinct_features();
        let labels = DVector::from_vec(vec![3, 2, 1]);
        let dataset = Dataset::new(features.clone(), labels.clone());
        classifier.fit(&dataset).unwrap();

        // Predicting on the training rows must reproduce the training labels.
        let outcome = classifier.predict(&features);

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), labels);
    }
}