// nnet_and_gate/
// nnet-and_gate.rs

extern crate rand;
extern crate rusty_machine;

use rand::{random, Closed01};
use std::vec::Vec;

use rusty_machine::learning::nnet::{BCECriterion, NeuralNet};
use rusty_machine::learning::optim::grad_desc::StochasticGD;
use rusty_machine::learning::toolkit::regularization::Regularization;
use rusty_machine::learning::SupModel;
use rusty_machine::linalg::Matrix;

14// AND gate
15fn main() {
16    println!("AND gate learner sample:");
17
18    const THRESHOLD: f64 = 0.7;
19
20    const SAMPLES: usize = 10000;
21    println!("Generating {} training data and labels...", SAMPLES as u32);
22
23    let mut input_data = Vec::with_capacity(SAMPLES * 2);
24    let mut label_data = Vec::with_capacity(SAMPLES);
25
26    for _ in 0..SAMPLES {
27        // The two inputs are "signals" between 0 and 1
28        let Closed01(left) = random::<Closed01<f64>>();
29        let Closed01(right) = random::<Closed01<f64>>();
30        input_data.push(left);
31        input_data.push(right);
32        if left > THRESHOLD && right > THRESHOLD {
33            label_data.push(1.0);
34        } else {
35            label_data.push(0.0)
36        }
37    }
38
39    let inputs = Matrix::new(SAMPLES, 2, input_data);
40    let targets = Matrix::new(SAMPLES, 1, label_data);
41
42    let layers = &[2, 1];
43    let criterion = BCECriterion::new(Regularization::L2(0.));
44    let mut model = NeuralNet::new(layers, criterion, StochasticGD::default());
45
46    println!("Training...");
47    // Our train function returns a Result<(), E>
48    model.train(&inputs, &targets).unwrap();
49
50    let test_cases = vec![
51        0.0, 0.0,
52        0.0, 1.0,
53        1.0, 1.0,
54        1.0, 0.0,
55        ];
56    let expected = vec![
57        0.0,
58        0.0,
59        1.0,
60        0.0,
61        ];
62    let test_inputs = Matrix::new(test_cases.len() / 2, 2, test_cases);
63    let res = model.predict(&test_inputs).unwrap();
64
65    println!("Evaluation...");
66    let mut hits = 0;
67    let mut misses = 0;
68    // Evaluation
69    println!("Got\tExpected");
70    for (idx, prediction) in res.into_vec().iter().enumerate() {
71        println!("{:.2}\t{}", prediction, expected[idx]);
72        if (prediction - 0.5) * (expected[idx] - 0.5) > 0. {
73            hits += 1;
74        } else {
75            misses += 1;
76        }
77    }
78
79    println!("Hits: {}, Misses: {}", hits, misses);
80    let hits_f = hits as f64;
81    let total = (hits + misses) as f64;
82    println!("Accuracy: {}%", (hits_f / total) * 100.);
83}