// Source file: svm_sign_learner/svm-sign_learner.rs

extern crate rusty_machine;

use rusty_machine::learning::svm::SVM;
// `SupModel` is the supervised-model trait; it must be in scope so that
// `train` and `predict` can be called on the SVM.
use rusty_machine::learning::SupModel;
use rusty_machine::learning::toolkit::kernel::HyperTan;

use rusty_machine::linalg::Matrix;
use rusty_machine::linalg::Vector;

11// Sign learner:
12//   * Model input a float number
13//   * Model output: A float representing the input sign.
14//       If the input is positive, the output is close to 1.0.
15//       If the input is negative, the output is close to -1.0.
16//   * Model generated with the SVM API.
17fn main() {
18    println!("Sign learner sample:");
19
20    println!("Training...");
21    // Training data
22    let inputs = Matrix::new(11, 1, vec![
23                             -0.1, -2., -9., -101., -666.7,
24                             0., 0.1, 1., 11., 99., 456.7
25                             ]);
26    let targets = Vector::new(vec![
27                              -1., -1., -1., -1., -1.,
28                              1., 1., 1., 1., 1., 1.
29                              ]);
30
31    // Trainee
32    let mut svm_mod = SVM::new(HyperTan::new(100., 0.), 0.3);
33    // Our train function returns a Result<(), E>
34    svm_mod.train(&inputs, &targets).unwrap();
35
36    println!("Evaluation...");
37    let mut hits = 0;
38    let mut misses = 0;
39    // Evaluation
40    //   Note: We could pass all input values at once to the `predict` method!
41    //         Here, we use a loop just to count and print logs.
42    for n in (-1000..1000).filter(|&x| x % 100 == 0) {
43        let nf = n as f64;
44        let input = Matrix::new(1, 1, vec![nf]);
45        let out = svm_mod.predict(&input).unwrap();
46        let res = if out[0] * nf > 0. {
47            hits += 1;
48            true
49        } else if nf == 0. {
50            hits += 1;
51            true
52        } else {
53            misses += 1;
54            false
55        };
56
57        println!("{} -> {}: {}", Matrix::data(&input)[0], out[0], res);
58    }
59
60    println!("Performance report:");
61    println!("Hits: {}, Misses: {}", hits, misses);
62    let hits_f = hits as f64;
63    let total = (hits + misses) as f64;
64    println!("Accuracy: {}", (hits_f / total) * 100.);
65}