rusty_machine/learning/toolkit/
rand_utils.rs

//! Utility functions for randomness.
//!
//! This module provides the random sampling and shuffling routines used
//! within the learning modules.
5
6use rand::{Rng, thread_rng};
7
8/// ```
9/// use rusty_machine::learning::toolkit::rand_utils;
10///
11/// let mut pool = &mut [1,2,3,4];
12/// let sample = rand_utils::reservoir_sample(pool, 3);
13///
14/// println!("{:?}", sample);
15/// ```
16pub fn reservoir_sample<T: Copy>(pool: &[T], reservoir_size: usize) -> Vec<T> {
17    assert!(pool.len() >= reservoir_size,
18            "Sample size is greater than total.");
19
20    let mut pool_mut = &pool[..];
21
22    let mut res = pool_mut[..reservoir_size].to_vec();
23    pool_mut = &pool_mut[reservoir_size..];
24
25    let mut ele_seen = reservoir_size;
26    let mut rng = thread_rng();
27
28    while pool_mut.len() > 0 {
29        ele_seen += 1;
30        let r = rng.gen_range(0, ele_seen);
31
32        let p_0 = pool_mut[0];
33        pool_mut = &pool_mut[1..];
34
35        if r < reservoir_size {
36            res[r] = p_0;
37        }
38    }
39
40    res
41}
42
43/// The inside out Fisher-Yates algorithm.
44///
45/// # Examples
46///
47/// ```
48/// use rusty_machine::learning::toolkit::rand_utils;
49///
50/// // Collect the numbers 0..5
51/// let a = (0..5).collect::<Vec<_>>();
52///
53/// // Perform a Fisher-Yates shuffle to get a random permutation
54/// let permutation = rand_utils::fisher_yates(&a);
55/// ```
56pub fn fisher_yates<T: Copy>(arr: &[T]) -> Vec<T> {
57    let n = arr.len();
58    let mut rng = thread_rng();
59
60    let mut shuffled_arr = Vec::with_capacity(n);
61
62    unsafe {
63        // We set the length here
64        // We only access data which has been initialized in the algorithm
65        shuffled_arr.set_len(n);
66    }
67
68    for i in 0..n {
69        let j = rng.gen_range(0, i + 1);
70
71        // If j isn't the last point in the active shuffled array
72        if j != i {
73            // Copy value at position j to the end of the shuffled array
74            // This is safe as we only read initialized data (j < i)
75            let x = shuffled_arr[j];
76            shuffled_arr[i] = x;
77        }
78
79        // Place value at end of active array into shuffled array
80        shuffled_arr[j] = arr[i];
81    }
82
83    shuffled_arr
84}
85
86/// The in place Fisher-Yates shuffle.
87///
88/// # Examples
89///
90/// ```
91/// use rusty_machine::learning::toolkit::rand_utils;
92///
93/// // Collect the numbers 0..5
94/// let mut a = (0..5).collect::<Vec<_>>();
95///
96/// // Permute the values in place with Fisher-Yates
97/// rand_utils::in_place_fisher_yates(&mut a);
98/// ```
99pub fn in_place_fisher_yates<T>(arr: &mut [T]) {
100    let n = arr.len();
101    let mut rng = thread_rng();
102
103    for i in 0..n {
104        // Swap i with a random point after it
105        let j = rng.gen_range(0, n - i);
106        arr.swap(i, i + j);
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reservoir_sample() {
        let a = vec![1, 2, 3, 4, 5, 6, 7];

        let b = reservoir_sample(&a, 3);

        assert_eq!(b.len(), 3);
        // Every sampled value must be drawn from the pool.
        for val in b.iter() {
            assert!(a.contains(val));
        }
    }

    #[test]
    fn test_reservoir_sample_whole_pool() {
        // With no elements left to stream, the reservoir is the pool itself.
        let a = vec![1, 2, 3];

        assert_eq!(reservoir_sample(&a, 3), a);
    }

    #[test]
    fn test_reservoir_sample_zero_size() {
        let a = vec![1, 2, 3];

        assert!(reservoir_sample(&a, 0).is_empty());
    }

    #[test]
    #[should_panic(expected = "Sample size is greater than total.")]
    fn test_reservoir_sample_too_large() {
        reservoir_sample(&[1, 2], 3);
    }

    #[test]
    fn test_fisher_yates() {
        let a = (0..10).collect::<Vec<_>>();

        let mut b = fisher_yates(&a);

        // A permutation has the same length and the same values.
        b.sort();
        assert_eq!(b, a);
    }

    #[test]
    fn test_fisher_yates_empty() {
        let a: [u8; 0] = [];

        assert!(fisher_yates(&a).is_empty());
    }

    #[test]
    fn test_in_place_fisher_yates() {
        let mut a = (0..10).collect::<Vec<_>>();

        in_place_fisher_yates(&mut a);

        // Sorting the shuffled slice must recover the original values.
        a.sort();
        assert_eq!(a, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_in_place_fisher_yates_empty() {
        let mut a: [u8; 0] = [];

        in_place_fisher_yates(&mut a);
        assert!(a.is_empty());
    }
}