1use rand::Rng;
13
14pub fn create_rng(seed: u64) -> rand::rngs::SmallRng {
28 use rand::SeedableRng;
29 rand::rngs::SmallRng::seed_from_u64(seed)
30}
31
32pub fn shuffle<T, R: Rng>(slice: &mut [T], rng: &mut R) {
58 let n = slice.len();
59 if n <= 1 {
60 return;
61 }
62 for i in (1..n).rev() {
63 let j = rng.random_range(0..=i);
64 slice.swap(i, j);
65 }
66}
67
68pub fn shuffled_indices<R: Rng>(n: usize, rng: &mut R) -> Vec<usize> {
88 let mut indices: Vec<usize> = (0..n).collect();
89 shuffle(&mut indices, rng);
90 indices
91}
92
93pub fn weighted_choose<R: Rng>(weights: &[f64], rng: &mut R) -> Option<usize> {
113 if weights.is_empty() {
114 return None;
115 }
116
117 let total: f64 = weights.iter().filter(|w| **w > 0.0).sum();
118 if total <= 0.0 {
119 return None;
120 }
121
122 let threshold = rng.random_range(0.0..total);
123 let mut cumulative = 0.0;
124 for (i, &w) in weights.iter().enumerate() {
125 if w > 0.0 {
126 cumulative += w;
127 if cumulative > threshold {
128 return Some(i);
129 }
130 }
131 }
132
133 Some(weights.len() - 1)
135}
136
137pub struct WeightedSampler {
159 cumulative: Vec<f64>,
160 total: f64,
161}
162
163impl WeightedSampler {
164 pub fn new(weights: &[f64]) -> Option<Self> {
169 if weights.is_empty() {
170 return None;
171 }
172
173 let mut cumulative = Vec::with_capacity(weights.len());
174 let mut total = 0.0;
175 for &w in weights {
176 if w > 0.0 {
177 total += w;
178 }
179 cumulative.push(total);
180 }
181
182 if total <= 0.0 {
183 return None;
184 }
185
186 Some(Self { cumulative, total })
187 }
188
189 pub fn sample<R: Rng>(&self, rng: &mut R) -> usize {
194 let threshold = rng.random_range(0.0..self.total);
195 match self.cumulative.binary_search_by(|c| {
196 c.partial_cmp(&threshold)
197 .expect("cumulative values are finite")
198 }) {
199 Ok(i) => i,
200 Err(i) => i.min(self.cumulative.len() - 1),
201 }
202 }
203
204 pub fn len(&self) -> usize {
206 self.cumulative.len()
207 }
208
209 pub fn is_empty(&self) -> bool {
211 self.cumulative.is_empty()
212 }
213
214 pub fn total_weight(&self) -> f64 {
216 self.total
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_rng_deterministic() {
        let mut rng1 = create_rng(42);
        let mut rng2 = create_rng(42);
        let vals1: Vec<f64> = (0..10).map(|_| rng1.random()).collect();
        let vals2: Vec<f64> = (0..10).map(|_| rng2.random()).collect();
        assert_eq!(vals1, vals2);
    }

    #[test]
    fn test_shuffle_preserves_elements() {
        let mut v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut rng = create_rng(123);
        shuffle(&mut v, &mut rng);
        v.sort();
        assert_eq!(v, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_shuffle_empty() {
        let mut v: Vec<i32> = vec![];
        let mut rng = create_rng(0);
        shuffle(&mut v, &mut rng);
        assert!(v.is_empty());
    }

    #[test]
    fn test_shuffle_single() {
        let mut v = vec![42];
        let mut rng = create_rng(0);
        shuffle(&mut v, &mut rng);
        assert_eq!(v, vec![42]);
    }

    #[test]
    fn test_shuffle_actually_shuffles() {
        let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut v = original.clone();
        // Fixed seed: deterministic, but guards against a no-op shuffle.
        let mut rng = create_rng(42);
        shuffle(&mut v, &mut rng);
        assert_ne!(v, original, "shuffle should change order (probabilistic)");
    }

    #[test]
    fn test_shuffled_indices() {
        let mut rng = create_rng(42);
        let indices = shuffled_indices(10, &mut rng);
        assert_eq!(indices.len(), 10);
        let mut sorted = indices.clone();
        sorted.sort();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_weighted_choose_basic() {
        let mut rng = create_rng(42);
        let weights = [0.0, 0.0, 1.0];
        for _ in 0..100 {
            assert_eq!(weighted_choose(&weights, &mut rng), Some(2));
        }
    }

    #[test]
    fn test_weighted_choose_skips_interior_zero_weight() {
        let mut rng = create_rng(7);
        let weights = [1.0, 0.0, 1.0];
        for _ in 0..200 {
            assert_ne!(weighted_choose(&weights, &mut rng), Some(1));
        }
    }

    #[test]
    fn test_weighted_choose_empty() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_all_zero() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[0.0, 0.0], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_distribution() {
        let mut rng = create_rng(42);
        let weights = [1.0, 3.0];
        let mut counts = [0u32; 2];
        let n = 10000;
        for _ in 0..n {
            let idx = weighted_choose(&weights, &mut rng).unwrap();
            counts[idx] += 1;
        }
        let ratio = counts[1] as f64 / counts[0] as f64;
        assert!(
            (ratio - 3.0).abs() < 0.5,
            "expected ratio ~3.0, got {ratio}"
        );
    }

    #[test]
    fn test_weighted_sampler_basic() {
        let sampler = WeightedSampler::new(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(sampler.len(), 3);
        assert!(!sampler.is_empty());
        assert!((sampler.total_weight() - 6.0).abs() < 1e-15);
    }

    #[test]
    fn test_weighted_sampler_deterministic_weight() {
        let sampler = WeightedSampler::new(&[0.0, 0.0, 1.0]).unwrap();
        let mut rng = create_rng(42);
        for _ in 0..100 {
            assert_eq!(sampler.sample(&mut rng), 2);
        }
    }

    #[test]
    fn test_weighted_sampler_skips_interior_zero_weight() {
        let sampler = WeightedSampler::new(&[1.0, 0.0, 1.0]).unwrap();
        let mut rng = create_rng(7);
        for _ in 0..200 {
            assert_ne!(sampler.sample(&mut rng), 1);
        }
    }

    #[test]
    fn test_weighted_sampler_distribution() {
        let sampler = WeightedSampler::new(&[1.0, 3.0]).unwrap();
        let mut rng = create_rng(42);
        let mut counts = [0u32; 2];
        let n = 10000;
        for _ in 0..n {
            counts[sampler.sample(&mut rng)] += 1;
        }
        let ratio = counts[1] as f64 / counts[0] as f64;
        assert!(
            (ratio - 3.0).abs() < 0.5,
            "expected ratio ~3.0, got {ratio}"
        );
    }

    #[test]
    fn test_weighted_sampler_empty() {
        assert!(WeightedSampler::new(&[]).is_none());
    }
}
357
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        /// Shuffling only reorders: the multiset of elements is unchanged.
        #[test]
        fn shuffle_is_permutation(
            seed in 0_u64..10000,
            data in proptest::collection::vec(0_i32..1000, 0..50),
        ) {
            let mut shuffled = data.clone();
            let mut rng = create_rng(seed);
            shuffle(&mut shuffled, &mut rng);
            let mut sorted_orig = data;
            sorted_orig.sort_unstable();
            shuffled.sort_unstable();
            prop_assert_eq!(sorted_orig, shuffled);
        }

        /// `weighted_choose` returns `Some` exactly when some weight is positive. When it does,
        /// the index is in range and points at a positive weight.
        #[test]
        fn weighted_choose_returns_valid_index(
            seed in 0_u64..10000,
            weights in proptest::collection::vec(0.0_f64..10.0, 1..20),
        ) {
            let has_positive = weights.iter().any(|&w| w > 0.0);
            let mut rng = create_rng(seed);
            match weighted_choose(&weights, &mut rng) {
                Some(idx) => {
                    prop_assert!(has_positive, "got Some({}) with no positive weight", idx);
                    prop_assert!(idx < weights.len());
                    prop_assert!(weights[idx] > 0.0, "chose zero-weight index {}", idx);
                }
                None => {
                    prop_assert!(!has_positive, "got None despite a positive weight");
                }
            }
        }

        /// `WeightedSampler` is constructible exactly when some weight is positive, and every
        /// sample points at a positive weight.
        #[test]
        fn weighted_sampler_returns_positive_weight_index(
            seed in 0_u64..10000,
            weights in proptest::collection::vec(0.0_f64..10.0, 1..20),
        ) {
            let has_positive = weights.iter().any(|&w| w > 0.0);
            match WeightedSampler::new(&weights) {
                Some(sampler) => {
                    prop_assert!(has_positive, "sampler built with no positive weight");
                    let mut rng = create_rng(seed);
                    for _ in 0..10 {
                        let idx = sampler.sample(&mut rng);
                        prop_assert!(idx < weights.len());
                        prop_assert!(weights[idx] > 0.0, "sampled zero-weight index {}", idx);
                    }
                }
                None => {
                    prop_assert!(!has_positive, "new returned None despite a positive weight");
                }
            }
        }
    }
}