rusty_ai/bayes/categorical.rs

use crate::{
    data::dataset::{Dataset, WholeNumber},
    metrics::confusion::ClassificationMetrics,
};
use nalgebra::{DMatrix, DVector};
use std::{
    collections::{HashMap, HashSet},
    error::Error,
};

/// Implementation of the Categorical Naive Bayes classifier.
///
/// This struct represents a Categorical Naive Bayes classifier, which is a probabilistic
/// classifier that assumes independence between features given the class label. It is
/// specifically designed for categorical features.
///
/// # Example
///
/// ```
/// use rusty_ai::bayes::categorical::CategoricalNB;
/// use rusty_ai::data::dataset::Dataset;
/// use nalgebra::{DMatrix, DVector};
///
/// // Create a new CategoricalNB classifier
/// let mut model = CategoricalNB::new();
///
/// // Fit the classifier to a dataset
/// let x = DMatrix::from_row_slice(2, 3, &[1, 2, 3, 2, 3, 4]);
/// let y = DVector::from_vec(vec![0, 1]);
/// let dataset = Dataset::new(x, y);
/// model.fit(&dataset).unwrap();
///
/// // Predict the class labels for new data
/// let x_test = DMatrix::from_row_slice(2, 3, &[1, 3, 4, 2, 2, 3]);
/// let predictions = model.predict(&x_test).unwrap();
/// assert_eq!(predictions, DVector::from_vec(vec![1, 0]));
/// ```
pub struct CategoricalNB<T: WholeNumber> {
    /// Per-class, per-feature smoothed frequencies of each feature value.
    feature_class_freq: HashMap<T, DVector<HashMap<T, f64>>>,
    /// Relative frequency (prior) of each class label.
    label_class_freq: HashMap<T, f64>,
    /// Number of unique values observed for each feature during fitting.
    unique_feature_values_count: Vec<usize>,
}

impl<T: WholeNumber> Default for CategoricalNB<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: WholeNumber> ClassificationMetrics<T> for CategoricalNB<T> {}

impl<T: WholeNumber> CategoricalNB<T> {
    /// Creates a new instance of the CategoricalNB classifier.
    ///
    /// This function initializes the classifier with empty frequency maps and an empty
    /// vector to store the count of unique feature values.
    ///
    /// # Returns
    ///
    /// A new instance of the CategoricalNB classifier.
    pub fn new() -> Self {
        Self {
            feature_class_freq: HashMap::new(),
            label_class_freq: HashMap::new(),
            unique_feature_values_count: Vec::new(),
        }
    }

    /// Returns a reference to the feature class frequency map.
    ///
    /// This map stores, for each class label, the smoothed frequency of each
    /// feature value.
    ///
    /// # Returns
    ///
    /// A reference to the feature class frequency map.
    pub fn feature_class_freq(&self) -> &HashMap<T, DVector<HashMap<T, f64>>> {
        &self.feature_class_freq
    }

    /// Returns a reference to the label class frequency map.
    ///
    /// This map stores the relative frequency of each class label.
    ///
    /// # Returns
    ///
    /// A reference to the label class frequency map.
    pub fn label_class_freq(&self) -> &HashMap<T, f64> {
        &self.label_class_freq
    }

    /// Fits the classifier to a dataset.
    ///
    /// This function fits the classifier to the given dataset by computing, with
    /// Laplace (add-one) smoothing, the frequency of each feature value within each
    /// class, along with the relative frequency of each class label. It also counts
    /// the unique values of each feature.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The dataset to fit the classifier to.
    ///
    /// # Returns
    ///
    /// A `Result` indicating whether the fitting process was successful or an error occurred.
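    ///
    /// # Example
    ///
    /// A minimal sketch of the smoothed estimates: with add-one smoothing, the
    /// conditional frequency of value `v` for feature `i` in class `c` is
    /// `(count(v) + 1) / (n_c + K_i)`, where `n_c` is the number of rows of
    /// class `c` and `K_i` the number of unique values of feature `i`. The data
    /// below is chosen arbitrarily for illustration.
    ///
    /// ```
    /// use rusty_ai::bayes::categorical::CategoricalNB;
    /// use rusty_ai::data::dataset::Dataset;
    /// use nalgebra::{DMatrix, DVector};
    ///
    /// let mut model = CategoricalNB::<i32>::new();
    /// let x = DMatrix::from_row_slice(4, 2, &[1, 0, 1, 1, 0, 0, 0, 1]);
    /// let y = DVector::from_vec(vec![0, 0, 1, 1]);
    /// model.fit(&Dataset::new(x, y)).unwrap();
    ///
    /// // Class 0 covers 2 rows; feature 0 takes the value 1 in both of them
    /// // and has 2 unique values overall, so its estimate is (2 + 1) / (2 + 2).
    /// assert_eq!(model.feature_class_freq()[&0][0][&1], 0.75);
    /// ```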
    pub fn fit(&mut self, dataset: &Dataset<T, T>) -> Result<String, Box<dyn Error>> {
        let (x, y) = dataset.into_parts();
        let y_classes = y.iter().cloned().collect::<HashSet<_>>();

        // Count the unique values of each feature across the whole dataset;
        // these counts form the smoothing denominators below.
        let mut unique_feature_values_count_temp = vec![HashSet::new(); x.ncols()];

        x.column_iter().enumerate().for_each(|(idx, feature)| {
            feature.iter().for_each(|&val| {
                unique_feature_values_count_temp[idx].insert(val);
            })
        });

        self.unique_feature_values_count = unique_feature_values_count_temp
            .iter()
            .map(|set| set.len())
            .collect::<Vec<_>>();

        for y_class in y_classes {
            // Select the rows that belong to the current class.
            let class_mask = y.map(|label| label == y_class);
            let class_indices = class_mask
                .iter()
                .enumerate()
                .filter(|&(_, &value)| value)
                .map(|(index, _)| index)
                .collect::<Vec<_>>();

            let x_y_class = x.select_rows(class_indices.as_slice());

            let mut all_features_freq = DVector::from_element(x.ncols(), HashMap::new());
            for (idx, feature) in x_y_class.column_iter().enumerate() {
                let feature_count = feature.iter().fold(HashMap::new(), |mut acc, &val| {
                    *acc.entry(val).or_insert(0) += 1;
                    acc
                });
                // Laplace add-one smoothing: (count + 1) / (n_c + K_i).
                let total_count =
                    class_indices.len() as f64 + self.unique_feature_values_count[idx] as f64;
                let feature_freq = feature_count
                    .into_iter()
                    .map(|(value, count)| (value, (count as f64 + 1.0) / total_count))
                    .collect();
                all_features_freq[idx] = feature_freq;
            }

            // Class prior: the fraction of training rows with this label.
            let label_class_freq = class_indices.len() as f64 / y.nrows() as f64;

            self.label_class_freq.insert(y_class, label_class_freq);
            self.feature_class_freq.insert(y_class, all_features_freq);
        }

        Ok("Finished fitting".into())
    }

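    /// Predicts the class label for a single feature vector.
    ///
    /// Scores each class in log space by summing the log prior and the log of
    /// each feature value's smoothed conditional frequency, falling back to a
    /// smoothed constant for values unseen during fitting, and returns the
    /// class with the highest total.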
    fn predict_single(&self, x: &DVector<T>) -> Result<T, Box<dyn Error>> {
        let mut max_prob = f64::NEG_INFINITY;
        let mut max_class = T::from_i8(0).unwrap();

        for (y_class, label_freq) in &self.label_class_freq {
            // Start from the log prior and accumulate log likelihoods.
            let mut prob = label_freq.ln();

            for (idx, feature) in x.iter().enumerate() {
                let feature_probs = &self
                    .feature_class_freq
                    .get(y_class)
                    .ok_or(format!("Class {:?} was not seen during fitting.", y_class))?[idx];

                // Smoothed fallback for feature values unseen during fitting;
                // the class priors sum to 1, so this is 1 / (1 + K_i).
                let total_feature_count = self.label_class_freq.values().sum::<f64>()
                    + self.unique_feature_values_count[idx] as f64;
                let feature_prob = feature_probs
                    .get(feature)
                    .unwrap_or(&(1.0 / total_feature_count))
                    .ln();

                prob += feature_prob;
            }

            if prob > max_prob {
                max_prob = prob;
                max_class = *y_class;
            }
        }

        Ok(max_class)
    }

    /// Predicts the class labels for a matrix of feature values.
    ///
    /// This function predicts the class label for each row in the given matrix of
    /// feature values. It uses the fitted model to calculate the probability of each
    /// class label for each row and selects the class label with the highest probability
    /// as the predicted label.
    ///
    /// # Arguments
    ///
    /// * `x` - The matrix of feature values.
    ///
    /// # Returns
    ///
    /// A `Result` containing a vector of predicted class labels or an error if the
    /// prediction process failed.
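    ///
    /// # Example
    ///
    /// A minimal sketch (the data is arbitrary) showing that rows containing
    /// feature values never seen during fitting still receive a prediction
    /// through the smoothed fallback:
    ///
    /// ```
    /// use rusty_ai::bayes::categorical::CategoricalNB;
    /// use rusty_ai::data::dataset::Dataset;
    /// use nalgebra::{DMatrix, DVector};
    ///
    /// let mut model = CategoricalNB::<i32>::new();
    /// let x = DMatrix::from_row_slice(2, 3, &[1, 2, 3, 2, 3, 4]);
    /// let y = DVector::from_vec(vec![0, 1]);
    /// model.fit(&Dataset::new(x, y)).unwrap();
    ///
    /// // The value 9 never appears in the training data.
    /// let x_test = DMatrix::from_row_slice(1, 3, &[9, 9, 9]);
    /// let predictions = model.predict(&x_test).unwrap();
    /// assert_eq!(predictions.len(), 1);
    /// ```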
    pub fn predict(&self, x: &DMatrix<T>) -> Result<DVector<T>, Box<dyn Error>> {
        let mut y_pred = Vec::new();

        for i in 0..x.nrows() {
            let x_row = x.row(i).transpose();
            let y_class = self.predict_single(&x_row)?;
            y_pred.push(y_class);
        }
        Ok(DVector::from_vec(y_pred))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    #[test]
    fn test_new() {
        let model = CategoricalNB::<i32>::new();

        assert!(model.feature_class_freq.is_empty());
        assert!(model.label_class_freq.is_empty());
    }

    #[test]
    fn test_fit() {
        let mut model = CategoricalNB::<i32>::new();

        let x = DMatrix::from_row_slice(3, 3, &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
        let y = DVector::from_vec(vec![1, 2, 3]);
        let dataset = Dataset::new(x, y);

        let result = model.fit(&dataset);

        assert!(result.is_ok());
        assert_eq!(model.label_class_freq.len(), 3);
        assert_eq!(model.feature_class_freq.len(), 3);
    }

    #[test]
    fn test_predict_single() {
        let mut model = CategoricalNB::<i32>::new();

        // Create a simple dataset and fit the model
        let x = DMatrix::from_row_slice(4, 2, &[1, 0, 1, 1, 0, 0, 0, 1]);
        let y = DVector::from_vec(vec![0, 0, 1, 1]);
        let dataset = Dataset::new(x.clone(), y);
        model.fit(&dataset).unwrap();

        // Predict a single instance
        let test_instance = x.row(0).transpose();
        let result = model.predict_single(&test_instance).unwrap();

        // Check that the prediction matches the expected class
        assert_eq!(result, 0);
    }

    #[test]
    fn test_predict_with_unseen_feature_value() {
        let mut model = CategoricalNB::<i32>::new();

        // Create a simple dataset and fit the model
        let x = DMatrix::from_row_slice(4, 2, &[1, 0, 1, 1, 0, 0, 0, 1]);
        let y = DVector::from_vec(vec![0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);
        model.fit(&dataset).unwrap();

        // Predict an instance whose feature values never appear in training
        let test_instance = DVector::from_vec(vec![2, 2]);
        let result = model.predict_single(&test_instance).unwrap();

        // The smoothed fallback assigns equal likelihood to both classes here,
        // so only check that a valid class is returned without errors.
        assert!(result == 0 || result == 1);
    }

    #[test]
    fn test_predict() {
        let mut model = CategoricalNB::<i32>::new();

        let x = DMatrix::from_row_slice(3, 3, &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
        let y = DVector::from_vec(vec![3, 2, 1]);
        let dataset = Dataset::new(x.clone(), y.clone());

        model.fit(&dataset).unwrap();
        let result = model.predict(&x);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), y);
    }
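
    // A sketch of an additional check (dataset chosen arbitrarily): the class
    // priors stored by `fit` should be the relative label frequencies.
    #[test]
    fn test_label_class_freq() {
        let mut model = CategoricalNB::<i32>::new();

        let x = DMatrix::from_row_slice(5, 1, &[1, 1, 0, 0, 1]);
        let y = DVector::from_vec(vec![0, 0, 0, 1, 2]);
        model.fit(&Dataset::new(x, y)).unwrap();

        // Class 0 appears in 3 of 5 rows; classes 1 and 2 in 1 of 5 each.
        assert_eq!(model.label_class_freq()[&0], 0.6);
        assert_eq!(model.label_class_freq()[&1], 0.2);
        assert_eq!(model.label_class_freq()[&2], 0.2);
    }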
}