wheel_resample/
lib.rs

1//! Re-sampling functions for weighted sampling
2//!
3//! # Example
4//!
5//! ```
6//! use wheel_resample::resample;
7//!
8//! let mut rng = rand::thread_rng();
9//! let weights = [0.1, 0.2, 0.3, 0.8];
10//! let population = vec![1, 2, 3, 4];
11//! let samples = resample(&mut rng, &weights, &population);
12//!
13//! assert_eq!(samples.len(), population.len());
14//!
15//! // Make sure all samples are in the population
16//! assert!(samples.iter().all(|s| population.contains(s)));
17//! ```
18use num_traits::float::Float;
19use rand::{
20    distributions::{uniform::SampleUniform, Distribution},
21    Rng,
22};
23
24use rand_distr::Uniform;
25
26/// Returns a vector of n indices sampled according to the weights slice.
27///
28/// # Example
29///
30/// ```
31/// use wheel_resample::resample_idx;
32///
33/// let mut rng = rand::thread_rng();
34/// let weights = [0.1, 0.2, 0.3, 0.8];
35///
36/// let sample_idx = resample_idx(&mut rng, &weights, weights.len());
37/// assert_eq!(sample_idx.len(), weights.len());
38///
39/// let sample_2_idx = resample_idx(&mut rng, &weights, 2);
40/// assert_eq!(sample_2_idx.len(), 2);
41/// ```
42pub fn resample_idx<R, W>(rng: &mut R, weights: &[W], n: usize) -> Vec<usize>
43where
44    R: Rng,
45    W: SampleUniform + Float,
46{
47    let resampler = Resampler::new(rng, weights);
48    resampler.into_iter().take(n).collect()
49}
50
51/// Returns a vector of weighted samples drawn from the population vector.
52///
53/// # Example
54///
55/// ```
56/// use wheel_resample::resample;
57///
58/// let mut rng = rand::thread_rng();
59/// let weights = [0.1, 0.2, 0.3, 0.8];
60/// let population = vec![1, 2, 3, 4];
61/// let samples = resample(&mut rng, &weights, &population);
62///
63/// assert_eq!(samples.len(), population.len());
64///
65/// // Make sure all samples are in the population
66/// assert!(samples.iter().all(|s| population.contains(s)));
67/// ```
68pub fn resample<R, T, W>(rng: &mut R, weights: &[W], population: &[T]) -> Vec<T>
69where
70    R: Rng,
71    T: Clone,
72    W: SampleUniform + Float,
73{
74    let indices = resample_idx(rng, weights, population.len());
75
76    indices.iter().map(|&i| population[i].clone()).collect()
77}
78
79/// The Resampler can be turned into an Iterator to contineously pull sample indices
80///
81/// # Example
82///
83/// ```
84/// use wheel_resample::Resampler;
85///
86/// let mut rng = rand::thread_rng();
87/// let weights = [0.1, 0.2, 0.3, 0.8];
88/// let resampler = Resampler::new(&mut rng, &weights);
89///
90/// let population = vec![1, 2, 3, 4];
91/// let samples = resampler.into_iter().take(4).map(|i| population[i].clone()).collect::<Vec<u32>>();
92///
93/// // Make sure we got four samples
94/// assert_eq!(samples.len(), 4);
95///
96/// // Make sure all samples come from the population
97/// assert!(samples.iter().all(|s| population.contains(s)));
98/// ```
99///
100pub struct Resampler<'a, R: Rng, W: Float> {
101    rng: &'a mut R,
102    weights: &'a [W],
103}
104
105impl<'a, R: Rng, W: Float> Resampler<'a, R, W> {
106    /// Create Resampler instance from random generator and weights
107    pub fn new(rng: &'a mut R, weights: &'a [W]) -> Self {
108        Resampler { rng, weights }
109    }
110}
111
112impl<'a, R: Rng, W: SampleUniform + Float> IntoIterator for Resampler<'a, R, W> {
113    type Item = usize;
114    type IntoIter = ResampleIterator<'a, R, W>;
115
116    fn into_iter(mut self) -> Self::IntoIter {
117        let mut max_w = W::zero();
118        // Can we do this more elegant given floats are not Ord?
119        for &w in self.weights.iter() {
120            if w > max_w {
121                max_w = w;
122            }
123        }
124
125        let uniform_n = Uniform::new(0, self.weights.len());
126        let uniform_w = Uniform::new(W::zero(), W::from(2.0).unwrap() * max_w);
127
128        ResampleIterator {
129            b: W::zero(),
130            uniform_w,
131            index: uniform_n.sample(&mut self.rng),
132            resampler: self,
133        }
134    }
135}
136
137pub struct ResampleIterator<'a, R: Rng, W: SampleUniform + Float> {
138    b: W,
139    uniform_w: Uniform<W>,
140    index: usize,
141    resampler: Resampler<'a, R, W>,
142}
143
144impl<'a, R: Rng, W: SampleUniform + Float> Iterator for ResampleIterator<'a, R, W> {
145    type Item = usize;
146
147    fn next(&mut self) -> Option<usize> {
148        self.b = self.b + self.uniform_w.sample(self.resampler.rng);
149        while self.b > self.resampler.weights[self.index] {
150            self.b = self.b - self.resampler.weights[self.index];
151            self.index = (self.index + 1) % self.resampler.weights.len();
152        }
153
154        Some(self.index)
155    }
156}
157
#[cfg(test)]
mod tests {
    // The test functions share their names with the functions under test, so
    // those are always called through an explicit `super::` path.

    /// `resample_idx` returns exactly `n` valid indices for `n` both smaller
    /// and larger than the number of weights.
    #[test]
    fn resample_idx() {
        let mut rng = rand::thread_rng();
        let weights = [0.1, 0.2, 0.3, 0.8];

        // Make sure we can pull fewer samples than weights
        let sample_idx_2 = super::resample_idx(&mut rng, &weights, 2);

        assert_eq!(sample_idx_2.len(), 2);
        assert!(sample_idx_2.iter().all(|&i| i < weights.len()));

        // Make sure we can pull more samples than weights
        let sample_idx_6 = super::resample_idx(&mut rng, &weights, 6);

        assert_eq!(sample_idx_6.len(), 6);
        assert!(sample_idx_6.iter().all(|&i| i < weights.len()));
    }

    /// Iterating a `Resampler` directly yields valid indices.
    #[test]
    fn resample_iter() {
        let mut rng = rand::thread_rng();
        let weights = [0.1, 0.2, 0.3, 0.8];

        let resampler = super::Resampler::new(&mut rng, &weights);

        let samples = resampler.into_iter().take(4).collect::<Vec<usize>>();

        assert_eq!(samples.len(), 4);
        assert!(samples.iter().all(|&i| i < weights.len()));
    }

    /// `resample` returns one sample per population element, each taken from
    /// the population.
    #[test]
    fn resample_population() {
        let mut rng = rand::thread_rng();
        let weights = [0.1, 0.2, 0.3, 0.8];
        let population = vec![1, 2, 3, 4];

        let samples = super::resample(&mut rng, &weights, &population);

        assert_eq!(samples.len(), population.len());
        assert!(samples.iter().all(|s| population.contains(s)));
    }
}