1use num_traits::float::Float;
19use rand::{
20 distributions::{uniform::SampleUniform, Distribution},
21 Rng,
22};
23
24use rand_distr::Uniform;
25
26pub fn resample_idx<R, W>(rng: &mut R, weights: &[W], n: usize) -> Vec<usize>
43where
44 R: Rng,
45 W: SampleUniform + Float,
46{
47 let resampler = Resampler::new(rng, weights);
48 resampler.into_iter().take(n).collect()
49}
50
51pub fn resample<R, T, W>(rng: &mut R, weights: &[W], population: &[T]) -> Vec<T>
69where
70 R: Rng,
71 T: Clone,
72 W: SampleUniform + Float,
73{
74 let indices = resample_idx(rng, weights, population.len());
75
76 indices.iter().map(|&i| population[i].clone()).collect()
77}
78
/// Weighted index sampler over a borrowed slice of weights.
///
/// A `Resampler` only borrows its random number generator and its weights.
/// Calling `into_iter` on it produces a [`ResampleIterator`] that yields sampled
/// indices into `weights`. Weights are expected to be non-negative.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the type itself.
pub struct Resampler<'a, R, W> {
    /// Source of randomness, mutably borrowed for as long as the sampler lives.
    rng: &'a mut R,
    /// Weights that decide how likely each index is to be drawn.
    weights: &'a [W],
}
104
105impl<'a, R: Rng, W: Float> Resampler<'a, R, W> {
106 pub fn new(rng: &'a mut R, weights: &'a [W]) -> Self {
108 Resampler { rng, weights }
109 }
110}
111
112impl<'a, R: Rng, W: SampleUniform + Float> IntoIterator for Resampler<'a, R, W> {
113 type Item = usize;
114 type IntoIter = ResampleIterator<'a, R, W>;
115
116 fn into_iter(mut self) -> Self::IntoIter {
117 let mut max_w = W::zero();
118 for &w in self.weights.iter() {
120 if w > max_w {
121 max_w = w;
122 }
123 }
124
125 let uniform_n = Uniform::new(0, self.weights.len());
126 let uniform_w = Uniform::new(W::zero(), W::from(2.0).unwrap() * max_w);
127
128 ResampleIterator {
129 b: W::zero(),
130 uniform_w,
131 index: uniform_n.sample(&mut self.rng),
132 resampler: self,
133 }
134 }
135}
136
137pub struct ResampleIterator<'a, R: Rng, W: SampleUniform + Float> {
138 b: W,
139 uniform_w: Uniform<W>,
140 index: usize,
141 resampler: Resampler<'a, R, W>,
142}
143
144impl<'a, R: Rng, W: SampleUniform + Float> Iterator for ResampleIterator<'a, R, W> {
145 type Item = usize;
146
147 fn next(&mut self) -> Option<usize> {
148 self.b = self.b + self.uniform_w.sample(self.resampler.rng);
149 while self.b > self.resampler.weights[self.index] {
150 self.b = self.b - self.resampler.weights[self.index];
151 self.index = (self.index + 1) % self.resampler.weights.len();
152 }
153
154 Some(self.index)
155 }
156}
157
#[cfg(test)]
mod tests {
    #[test]
    fn resample_idx() {
        let mut rng = rand::thread_rng();
        let weights = [0.1, 0.2, 0.3, 0.8];

        let sample_idx_2 = super::resample_idx(&mut rng, &weights, 2);
        assert_eq!(sample_idx_2.len(), 2);
        assert!(sample_idx_2.iter().all(|&i| i < weights.len()));

        let sample_idx_6 = super::resample_idx(&mut rng, &weights, 6);
        assert_eq!(sample_idx_6.len(), 6);
        assert!(sample_idx_6.iter().all(|&i| i < weights.len()));
    }

    #[test]
    fn resample_idx_skips_zero_weights() {
        let mut rng = rand::thread_rng();
        // Only index 2 carries weight, so every draw must land on it.
        let weights = [0.0, 0.0, 1.0, 0.0];

        let samples = super::resample_idx(&mut rng, &weights, 50);
        assert_eq!(samples.len(), 50);
        assert!(samples.iter().all(|&i| i == 2));
    }

    #[test]
    fn resample_iter() {
        let mut rng = rand::thread_rng();
        let weights = [0.1, 0.2, 0.3, 0.8];

        let resampler = super::Resampler::new(&mut rng, &weights);
        let samples = resampler.into_iter().take(4).collect::<Vec<usize>>();

        assert_eq!(samples.len(), 4);
        assert!(samples.iter().all(|&i| i < weights.len()));
    }

    #[test]
    fn resample_keeps_population_size() {
        let mut rng = rand::thread_rng();
        let weights = [0.2, 0.5, 0.3];
        let population = ["a", "b", "c"];

        let resampled = super::resample(&mut rng, &weights, &population);
        assert_eq!(resampled.len(), population.len());
        assert!(resampled.iter().all(|p| population.contains(p)));

        let only_b = super::resample(&mut rng, &[0.0, 1.0, 0.0], &population);
        assert_eq!(only_b, vec!["b"; 3]);
    }
}