rusty_machine/learning/toolkit/
rand_utils.rs1use rand::{Rng, thread_rng};
7
8pub fn reservoir_sample<T: Copy>(pool: &[T], reservoir_size: usize) -> Vec<T> {
17 assert!(pool.len() >= reservoir_size,
18 "Sample size is greater than total.");
19
20 let mut pool_mut = &pool[..];
21
22 let mut res = pool_mut[..reservoir_size].to_vec();
23 pool_mut = &pool_mut[reservoir_size..];
24
25 let mut ele_seen = reservoir_size;
26 let mut rng = thread_rng();
27
28 while pool_mut.len() > 0 {
29 ele_seen += 1;
30 let r = rng.gen_range(0, ele_seen);
31
32 let p_0 = pool_mut[0];
33 pool_mut = &pool_mut[1..];
34
35 if r < reservoir_size {
36 res[r] = p_0;
37 }
38 }
39
40 res
41}
42
43pub fn fisher_yates<T: Copy>(arr: &[T]) -> Vec<T> {
57 let n = arr.len();
58 let mut rng = thread_rng();
59
60 let mut shuffled_arr = Vec::with_capacity(n);
61
62 unsafe {
63 shuffled_arr.set_len(n);
66 }
67
68 for i in 0..n {
69 let j = rng.gen_range(0, i + 1);
70
71 if j != i {
73 let x = shuffled_arr[j];
76 shuffled_arr[i] = x;
77 }
78
79 shuffled_arr[j] = arr[i];
81 }
82
83 shuffled_arr
84}
85
86pub fn in_place_fisher_yates<T>(arr: &mut [T]) {
100 let n = arr.len();
101 let mut rng = thread_rng();
102
103 for i in 0..n {
104 let j = rng.gen_range(0, n - i);
106 arr.swap(i, i + j);
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `v` sorted, so two vectors can be compared as multisets.
    fn sorted(mut v: Vec<i32>) -> Vec<i32> {
        v.sort();
        v
    }

    #[test]
    fn test_reservoir_sample() {
        let a = vec![1, 2, 3, 4, 5, 6, 7];

        let b = reservoir_sample(&a, 3);

        assert_eq!(b.len(), 3);

        // Every sampled value must come from the pool.
        for val in b.iter() {
            assert!(a.contains(val));
        }

        // The pool holds distinct values and each pool position is used at
        // most once, so the sample must not contain duplicates.
        let mut unique = sorted(b);
        unique.dedup();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_reservoir_sample_bounds() {
        let a = vec![1, 2, 3, 4];

        // An empty reservoir stays empty.
        assert!(reservoir_sample(&a, 0).is_empty());

        // Sampling the whole pool leaves no candidates to swap in, so the
        // pool comes back unchanged.
        assert_eq!(reservoir_sample(&a, a.len()), a);
    }

    #[test]
    #[should_panic(expected = "Sample size is greater than total.")]
    fn test_reservoir_sample_too_large() {
        let a = vec![1, 2, 3];

        let _ = reservoir_sample(&a, 4);
    }

    #[test]
    fn test_fisher_yates() {
        let a = (0..10).collect::<Vec<_>>();

        let b = fisher_yates(&a);

        // Same length and same multiset of values: `b` is a permutation.
        assert_eq!(b.len(), a.len());
        assert_eq!(sorted(b), a);
    }

    #[test]
    fn test_fisher_yates_empty() {
        let a: Vec<i32> = Vec::new();

        assert!(fisher_yates(&a).is_empty());
    }

    #[test]
    fn test_in_place_fisher_yates() {
        let mut a = (0..10).collect::<Vec<_>>();

        in_place_fisher_yates(&mut a);

        // Shuffling must neither drop nor duplicate elements.
        assert_eq!(a.len(), 10);
        assert_eq!(sorted(a), (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_in_place_fisher_yates_empty() {
        let mut a: Vec<i32> = Vec::new();

        in_place_fisher_yates(&mut a);

        assert!(a.is_empty());
    }
}