1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use std::cmp::Ordering::Equal;

/// Sorts `a` in ascending order and reorders `b` with the same permutation.
///
/// This is used to sort the target by a feature and sits at the core of the
/// decision tree algorithm. Pairs for which `partial_cmp` returns `None`
/// (e.g. when a NaN is involved) are treated as equal by the comparator.
///
/// Returns freshly allocated copies as `(sorted_a, reordered_b)`.
pub(crate) fn sort_two_vectors(a: &[f32], b: &[f32]) -> (Vec<f32>, Vec<f32>) {
    // Derive the ordering once from `a`, then replay it on both slices.
    let order = permutation::sort_by(a, |lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(Equal));

    (order.apply_slice(a), order.apply_slice(b))
}

/// Returns the arithmetic mean of `x`.
///
/// An empty slice yields `NaN`, because the result is zero divided by zero.
pub(crate) fn float_avg(x: &[f32]) -> f32 {
    let total: f32 = x.iter().sum();
    let count = x.len() as f32;
    total / count
}

/// Binarizes `x` against `clf_threshold`: every value greater than or equal to
/// the threshold becomes `1.0`, every other value becomes `0.0`.
///
/// Only compiled for test builds.
#[cfg(test)]
pub(crate) fn classification_threshold(x: &[f32], clf_threshold: f32) -> Vec<f32> {
    x.iter()
        .copied()
        // `bool -> u8 -> f32` yields exactly 1.0 for `true` and 0.0 for `false`.
        .map(|value| f32::from(u8::from(value >= clf_threshold)))
        .collect()
}

/// Computes the coefficient of determination (R²) of `x_pred` against `x_true`.
/// Used for testing the regression case.
///
/// The result is `1 - ss_res / ss_tot`, where `ss_res` is the sum of squared
/// differences between paired elements and `ss_tot` is the sum of squared
/// deviations of `x_true` from its mean.
///
/// Pairs are formed with `zip`, so surplus elements of the longer slice do not
/// contribute to `ss_res`. When `ss_tot` is zero the division produces a
/// non-finite result.
pub fn r2(x_true: &[f32], x_pred: &[f32]) -> f32 {
    let mean_true = float_avg(x_true);

    // Total sum of squares: spread of the true values around their mean.
    let ss_tot = x_true
        .iter()
        .map(|value| (value - mean_true).powf(2.0))
        .sum::<f32>();

    // Residual sum of squares: squared error of each prediction.
    let ss_res = x_true
        .iter()
        .zip(x_pred)
        .map(|(actual, predicted)| (actual - predicted).powf(2.0))
        .sum::<f32>();

    1.0 - ss_res / ss_tot
}

/// Computes the accuracy of a binary classification: the fraction of positions
/// at which `x_true` and `x_pred` hold exactly equal values. Used for testing.
///
/// Pairs are formed with `zip`, while the denominator is always
/// `x_true.len()`; if `x_pred` is shorter, the unpaired positions therefore
/// count as misses. An empty `x_true` yields `NaN` (zero divided by zero).
pub fn accuracy(x_true: &[f32], x_pred: &[f32]) -> f32 {
    // Count matches as an integer. Accumulating `1.0_f32` per match stops
    // increasing once the running total reaches 2^24, which would understate
    // the accuracy on very large inputs.
    let matches = x_true
        .iter()
        .zip(x_pred)
        .filter(|(xt, xp)| xt == xp)
        .count();

    matches as f32 / x_true.len() as f32
}

/// Creates the `StdRng` used for random decisions.
///
/// * `Some(seed)` — the generator is seeded deterministically with
///   `seed + offset`, computed with wrapping arithmetic.
/// * `None` — the generator is seeded from entropy and `offset` is ignored.
///
/// NOTE(review): `offset` presumably lets callers derive distinct generators
/// from one user-supplied seed — confirm against the call sites.
pub(crate) fn get_rng(maybe_seed: Option<u64>, offset: u64) -> rand::rngs::StdRng {
    match maybe_seed {
        // Wrapping add: a plain `+` panics on `u64` overflow in builds with
        // overflow checks enabled, e.g. for a seed close to `u64::MAX`.
        Some(seed) => rand::SeedableRng::seed_from_u64(seed.wrapping_add(offset)),
        None => rand::SeedableRng::from_entropy(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Three of the four predictions agree with the labels.
    #[test]
    fn test_accuracy() {
        let labels: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
        let predictions: [f32; 4] = [1.0, 1.0, 0.0, 0.0];

        assert_eq!(0.75, accuracy(&labels, &predictions));
    }

    /// The mean of 1, 2 and 3 is 2.
    #[test]
    fn test_float_avg() {
        let values: [f32; 3] = [1.0, 2.0, 3.0];

        assert_eq!(2.0, float_avg(&values));
    }

    /// Sorting the keys must carry the paired values along with them.
    #[test]
    fn test_sort_two_vectors() {
        let keys = vec![2.0, 3.0, 1.0];
        let paired = vec![6.0, 5.0, 4.0];

        let (sorted_keys, reordered_paired) = sort_two_vectors(&keys, &paired);

        assert_eq!(vec![1.0, 2.0, 3.0], sorted_keys);
        assert_eq!(vec![4.0, 6.0, 5.0], reordered_paired);
    }
}