naive_bayes_dogs/
naive_bayes_dogs.rs1extern crate rusty_machine;
2extern crate rand;
3
4use rand::Rand;
5use rand::distributions::Sample;
6use rand::distributions::normal::Normal;
7use rusty_machine::learning::naive_bayes::{self, NaiveBayes};
8use rusty_machine::linalg::{Matrix, BaseMatrix};
9use rusty_machine::learning::SupModel;
10
11
/// Coat colour of a dog — the class label the classifier is trained to predict.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Color {
    /// One-hot encoded as `[1, 0]` in the training targets.
    Red,
    /// One-hot encoded as `[0, 1]` in the training targets.
    White,
}
17
18#[derive(Clone, Debug)]
19struct Dog {
20 color: Color,
21 friendliness: f64,
22 furriness: f64,
23 speed: f64,
24}
25
26impl Rand for Dog {
27 fn rand<R: rand::Rng>(rng: &mut R) -> Self {
29 let mut red_dog_friendliness = Normal::new(0., 1.);
32 let mut red_dog_furriness = Normal::new(0., 1.);
33 let mut red_dog_speed = Normal::new(0., 1.);
34
35 let mut white_dog_friendliness = Normal::new(1., 1.);
36 let mut white_dog_furriness = Normal::new(1., 1.);
37 let mut white_dog_speed = Normal::new(-1., 1.);
38
39 let coin: f64 = rng.gen();
41 let color = if coin < 0.5 { Color::Red } else { Color::White };
42
43 match color {
44 Color::Red => {
45 Dog {
46 color: Color::Red,
47 friendliness: red_dog_friendliness.sample(rng),
49 furriness: red_dog_furriness.sample(rng),
50 speed: red_dog_speed.sample(rng),
51 }
52 },
53 Color::White => {
54 Dog {
55 color: Color::White,
56 friendliness: white_dog_friendliness.sample(rng),
57 furriness: white_dog_furriness.sample(rng),
58 speed: white_dog_speed.sample(rng),
59 }
60 },
61 }
62 }
63}
64
65fn generate_dog_data(training_set_size: u32, test_set_size: u32)
66 -> (Matrix<f64>, Matrix<f64>, Matrix<f64>, Vec<Dog>) {
67 let mut randomness = rand::StdRng::new()
68 .expect("we should be able to get an RNG");
69 let rng = &mut randomness;
70
71 let training_dogs = (0..training_set_size)
73 .map(|_| { Dog::rand(rng) })
74 .collect::<Vec<_>>();
75
76 let test_dogs = (0..test_set_size)
79 .map(|_| { Dog::rand(rng) })
80 .collect::<Vec<_>>();
81
82 let training_data: Vec<f64> = training_dogs.iter()
87 .flat_map(|dog| vec![dog.friendliness, dog.furriness, dog.speed])
88 .collect();
89 let training_matrix: Matrix<f64> = training_data.chunks(3).collect();
90 let target_data: Vec<f64> = training_dogs.iter()
91 .flat_map(|dog| match dog.color {
92 Color::Red => vec![1., 0.],
93 Color::White => vec![0., 1.],
94 })
95 .collect();
96 let target_matrix: Matrix<f64> = target_data.chunks(2).collect();
97
98 let test_data: Vec<f64> = test_dogs.iter()
100 .flat_map(|dog| vec![dog.friendliness, dog.furriness, dog.speed])
101 .collect();
102 let test_matrix: Matrix<f64> = test_data.chunks(3).collect();
103
104 (training_matrix, target_matrix, test_matrix, test_dogs)
105}
106
107fn evaluate_prediction(hits: &mut u32, dog: &Dog, prediction: &[f64]) -> (Color, bool) {
108 let predicted_color = dog.color;
109 let actual_color = if prediction[0] == 1. {
110 Color::Red
111 } else {
112 Color::White
113 };
114 let accurate = predicted_color == actual_color;
115 if accurate {
116 *hits += 1;
117 }
118 (actual_color, accurate)
119}
120
121fn main() {
122 let (training_set_size, test_set_size) = (1000, 1000);
123 let (training_matrix, target_matrix, test_matrix, test_dogs) = generate_dog_data(training_set_size, test_set_size);
125
126 let mut model = NaiveBayes::<naive_bayes::Gaussian>::new();
128 model.train(&training_matrix, &target_matrix)
129 .expect("failed to train model of dogs");
130
131 let predictions = model.predict(&test_matrix)
133 .expect("failed to predict dogs!?");
134
135 let mut hits = 0;
137 let unprinted_total = test_set_size.saturating_sub(10) as usize;
138 for (dog, prediction) in test_dogs.iter().zip(predictions.iter_rows()).take(unprinted_total) {
139 evaluate_prediction(&mut hits, dog, prediction);
140 }
141
142 if unprinted_total > 0 {
143 println!("...");
144 }
145
146 for (dog, prediction) in test_dogs.iter().zip(predictions.iter_rows()).skip(unprinted_total) {
147 let (actual_color, accurate) = evaluate_prediction(&mut hits, dog, prediction);
148 println!("Predicted: {:?}; Actual: {:?}; Accurate? {:?}",
149 dog.color, actual_color, accurate);
150 }
151
152 println!("Accuracy: {}/{} = {:.1}%", hits, test_set_size,
153 (f64::from(hits))/(f64::from(test_set_size)) * 100.);
154}