nnet_and_gate/
nnet-and_gate.rs1extern crate rusty_machine;
2extern crate rand;
3
4use rand::{random, Closed01};
5use std::vec::Vec;
6
7use rusty_machine::learning::nnet::{NeuralNet, BCECriterion};
8use rusty_machine::learning::toolkit::regularization::Regularization;
9use rusty_machine::learning::optim::grad_desc::StochasticGD;
10
11use rusty_machine::linalg::Matrix;
12use rusty_machine::learning::SupModel;
13
14fn main() {
16 println!("AND gate learner sample:");
17
18 const THRESHOLD: f64 = 0.7;
19
20 const SAMPLES: usize = 10000;
21 println!("Generating {} training data and labels...", SAMPLES as u32);
22
23 let mut input_data = Vec::with_capacity(SAMPLES * 2);
24 let mut label_data = Vec::with_capacity(SAMPLES);
25
26 for _ in 0..SAMPLES {
27 let Closed01(left) = random::<Closed01<f64>>();
29 let Closed01(right) = random::<Closed01<f64>>();
30 input_data.push(left);
31 input_data.push(right);
32 if left > THRESHOLD && right > THRESHOLD {
33 label_data.push(1.0);
34 } else {
35 label_data.push(0.0)
36 }
37 }
38
39 let inputs = Matrix::new(SAMPLES, 2, input_data);
40 let targets = Matrix::new(SAMPLES, 1, label_data);
41
42 let layers = &[2, 1];
43 let criterion = BCECriterion::new(Regularization::L2(0.));
44 let mut model = NeuralNet::new(layers, criterion, StochasticGD::default());
45
46 println!("Training...");
47 model.train(&inputs, &targets).unwrap();
49
50 let test_cases = vec![
51 0.0, 0.0,
52 0.0, 1.0,
53 1.0, 1.0,
54 1.0, 0.0,
55 ];
56 let expected = vec![
57 0.0,
58 0.0,
59 1.0,
60 0.0,
61 ];
62 let test_inputs = Matrix::new(test_cases.len() / 2, 2, test_cases);
63 let res = model.predict(&test_inputs).unwrap();
64
65 println!("Evaluation...");
66 let mut hits = 0;
67 let mut misses = 0;
68 println!("Got\tExpected");
70 for (idx, prediction) in res.into_vec().iter().enumerate() {
71 println!("{:.2}\t{}", prediction, expected[idx]);
72 if (prediction - 0.5) * (expected[idx] - 0.5) > 0. {
73 hits += 1;
74 } else {
75 misses += 1;
76 }
77 }
78
79 println!("Hits: {}, Misses: {}", hits, misses);
80 let hits_f = hits as f64;
81 let total = (hits + misses) as f64;
82 println!("Accuracy: {}%", (hits_f / total) * 100.);
83}