1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use scirs2_core::RngExt;
11
12use super::core::{rng_utils, Sampler, SamplerIterator};
13
/// Sampler that draws dataset indices with probability proportional to a
/// per-index weight.
///
/// Instances are normally built through [`WeightedRandomSampler::new`], which
/// validates the weights, and optionally seeded with
/// [`WeightedRandomSampler::with_generator`].
#[derive(Clone, Debug)]
pub struct WeightedRandomSampler {
    /// Relative weight of each index; entry `i` belongs to index `i`.
    weights: Vec<f64>,
    /// Number of indices requested per pass.
    num_samples: usize,
    /// `true` to allow an index to be drawn more than once.
    replacement: bool,
    /// Optional seed handed to the random number generator.
    generator: Option<u64>,
}
39
40impl WeightedRandomSampler {
41 pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
53 assert!(!weights.is_empty(), "weights cannot be empty");
54 assert!(
55 weights.iter().all(|&w| w >= 0.0),
56 "weights must be non-negative"
57 );
58 let weight_sum: f64 = weights.iter().sum();
59 assert!(
60 weight_sum > 0.0 && weight_sum.is_finite(),
61 "weights must sum to a positive finite value, got {weight_sum}"
62 );
63
64 Self {
65 weights,
66 num_samples,
67 replacement,
68 generator: None,
69 }
70 }
71
72 pub fn with_generator(mut self, seed: u64) -> Self {
78 self.generator = Some(seed);
79 self
80 }
81
82 pub fn weights(&self) -> &[f64] {
84 &self.weights
85 }
86
87 pub fn num_samples(&self) -> usize {
89 self.num_samples
90 }
91
92 pub fn replacement(&self) -> bool {
94 self.replacement
95 }
96
97 pub fn generator(&self) -> Option<u64> {
99 self.generator
100 }
101
102 fn sample_with_replacement(&self) -> Vec<usize> {
104 let mut rng = rng_utils::create_rng(self.generator);
106
107 let weight_sum: f64 = self.weights.iter().sum();
109 let mut cumulative_weights = Vec::with_capacity(self.weights.len());
110 let mut cumsum = 0.0;
111
112 for &weight in &self.weights {
113 cumsum += weight / weight_sum;
114 cumulative_weights.push(cumsum);
115 }
116
117 if let Some(last) = cumulative_weights.last_mut() {
119 *last = 1.0;
120 }
121
122 (0..self.num_samples)
124 .map(|_| {
125 let rand_val: f64 = rng.random();
126 cumulative_weights
128 .binary_search_by(|&x| {
129 x.partial_cmp(&rand_val)
130 .unwrap_or(std::cmp::Ordering::Equal)
131 })
132 .unwrap_or_else(|i| i)
133 .min(self.weights.len() - 1)
134 })
135 .collect()
136 }
137
138 fn sample_without_replacement(&self) -> Vec<usize> {
140 if self.num_samples >= self.weights.len() {
141 return (0..self.weights.len()).collect();
143 }
144
145 let mut rng = rng_utils::create_rng(self.generator);
147
148 let mut selected_indices = Vec::new();
150 let mut remaining_weights = self.weights.clone();
151 let mut remaining_indices: Vec<usize> = (0..self.weights.len()).collect();
152
153 for _ in 0..self.num_samples {
154 if remaining_indices.is_empty() {
155 break;
156 }
157
158 let weight_sum: f64 = remaining_weights.iter().sum();
160 if weight_sum <= 0.0 {
161 break;
162 }
163
164 let mut cumsum = 0.0;
165 let rand_val: f64 = rng.random::<f64>() * weight_sum;
166
167 let mut selected_idx = 0;
168 for (i, &weight) in remaining_weights.iter().enumerate() {
169 cumsum += weight;
170 if cumsum >= rand_val {
171 selected_idx = i;
172 break;
173 }
174 }
175
176 selected_indices.push(remaining_indices[selected_idx]);
178
179 remaining_indices.remove(selected_idx);
181 remaining_weights.remove(selected_idx);
182 }
183
184 selected_indices
185 }
186}
187
188impl Sampler for WeightedRandomSampler {
189 type Iter = SamplerIterator;
190
191 fn iter(&self) -> Self::Iter {
192 let indices = if self.replacement {
193 self.sample_with_replacement()
194 } else {
195 self.sample_without_replacement()
196 };
197
198 SamplerIterator::new(indices)
199 }
200
201 fn len(&self) -> usize {
202 self.num_samples
203 }
204}
205
/// Sampler that yields a fixed set of indices in shuffled order.
///
/// Build one with [`SubsetRandomSampler::new`] and optionally seed it with
/// [`SubsetRandomSampler::with_generator`].
#[derive(Clone, Debug)]
pub struct SubsetRandomSampler {
    /// The indices this sampler draws from.
    indices: Vec<usize>,
    /// Optional seed handed to the shuffling routine.
    generator: Option<u64>,
}
229
230impl SubsetRandomSampler {
231 pub fn new(indices: Vec<usize>) -> Self {
237 Self {
238 indices,
239 generator: None,
240 }
241 }
242
243 pub fn with_generator(mut self, seed: u64) -> Self {
249 self.generator = Some(seed);
250 self
251 }
252
253 pub fn indices(&self) -> &[usize] {
255 &self.indices
256 }
257
258 pub fn generator(&self) -> Option<u64> {
260 self.generator
261 }
262}
263
264impl Sampler for SubsetRandomSampler {
265 type Iter = SamplerIterator;
266
267 fn iter(&self) -> Self::Iter {
268 let mut shuffled_indices = self.indices.clone();
269 rng_utils::shuffle_indices(&mut shuffled_indices, self.generator);
270 SamplerIterator::new(shuffled_indices)
271 }
272
273 fn len(&self) -> usize {
274 self.indices.len()
275 }
276}
277
278pub fn weighted_random(
289 weights: Vec<f64>,
290 num_samples: usize,
291 replacement: bool,
292 seed: Option<u64>,
293) -> WeightedRandomSampler {
294 let mut sampler = WeightedRandomSampler::new(weights, num_samples, replacement);
295 if let Some(s) = seed {
296 sampler = sampler.with_generator(s);
297 }
298 sampler
299}
300
301pub fn subset_random(indices: Vec<usize>, seed: Option<u64>) -> SubsetRandomSampler {
310 let mut sampler = SubsetRandomSampler::new(indices);
311 if let Some(s) = seed {
312 sampler = sampler.with_generator(s);
313 }
314 sampler
315}
316
317pub fn balanced_weighted(
328 class_counts: &[usize],
329 num_samples: usize,
330 seed: Option<u64>,
331) -> WeightedRandomSampler {
332 let total_samples: usize = class_counts.iter().sum();
333 let num_classes = class_counts.len();
334
335 let weights: Vec<f64> = class_counts
337 .iter()
338 .map(|&count| {
339 if count > 0 {
340 total_samples as f64 / (num_classes as f64 * count as f64)
341 } else {
342 0.0
343 }
344 })
345 .collect();
346
347 weighted_random(weights, num_samples, true, seed)
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// A seeded with-replacement sampler reports its configuration and yields
    /// only in-range indices.
    #[test]
    fn test_weighted_sampler_basic() {
        let weights = vec![0.1, 0.1, 0.1, 0.1, 0.6];
        let sampler = WeightedRandomSampler::new(weights.clone(), 100, true).with_generator(42);

        assert_eq!(sampler.len(), 100);
        assert_eq!(sampler.weights(), weights.as_slice());
        assert_eq!(sampler.num_samples(), 100);
        assert!(sampler.replacement());
        assert_eq!(sampler.generator(), Some(42));

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 100);
        assert!(drawn.iter().all(|&idx| idx < 5));
    }

    /// Sampling without replacement yields distinct, in-range indices.
    #[test]
    fn test_weighted_sampler_without_replacement() {
        let sampler =
            WeightedRandomSampler::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3, false).with_generator(42);

        assert!(!sampler.replacement());

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 3);

        // Sorting and deduplicating must not shrink the list.
        let mut unique = drawn.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), 3);

        assert!(drawn.iter().all(|&idx| idx < 5));
    }

    /// With equal weights and this seed, every index shows up at least once.
    #[test]
    fn test_weighted_sampler_uniform_weights() {
        let sampler = WeightedRandomSampler::new(vec![1.0; 10], 50, true).with_generator(42);

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 50);

        let mut histogram = [0; 10];
        drawn.iter().for_each(|&idx| histogram[idx] += 1);

        for hits in histogram {
            assert!(hits > 0);
        }
    }

    /// When only one index has weight, it is the only index drawn.
    #[test]
    fn test_weighted_sampler_extreme_weights() {
        let sampler =
            WeightedRandomSampler::new(vec![0.0, 0.0, 0.0, 1.0], 10, true).with_generator(42);

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 10);

        for idx in drawn {
            assert_eq!(idx, 3);
        }
    }

    #[test]
    #[should_panic(expected = "weights cannot be empty")]
    fn test_weighted_sampler_empty_weights() {
        let _ = WeightedRandomSampler::new(Vec::new(), 10, true);
    }

    #[test]
    #[should_panic(expected = "weights must be non-negative")]
    fn test_weighted_sampler_negative_weights() {
        let _ = WeightedRandomSampler::new(vec![1.0, -1.0, 1.0], 10, true);
    }

    #[test]
    #[should_panic(expected = "weights must sum to a positive finite value")]
    fn test_weighted_sampler_zero_sum() {
        let _ = WeightedRandomSampler::new(vec![0.0, 0.0, 0.0], 10, true);
    }

    /// The subset sampler yields a permutation of exactly its input indices.
    #[test]
    fn test_subset_random_sampler() {
        let subset_indices = vec![1, 3, 5, 7, 9];
        let sampler = SubsetRandomSampler::new(subset_indices.clone()).with_generator(42);

        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.indices(), subset_indices.as_slice());
        assert_eq!(sampler.generator(), Some(42));

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 5);
        assert!(drawn.iter().all(|idx| subset_indices.contains(idx)));

        // Same multiset of indices once order is ignored.
        let mut got = drawn.clone();
        got.sort();
        let mut want = subset_indices;
        want.sort();
        assert_eq!(got, want);
    }

    #[test]
    fn test_subset_random_sampler_empty() {
        let sampler = SubsetRandomSampler::new(Vec::new());
        assert_eq!(sampler.len(), 0);
        assert!(sampler.is_empty());

        let drawn: Vec<usize> = sampler.iter().collect();
        assert!(drawn.is_empty());
    }

    #[test]
    fn test_subset_random_sampler_single() {
        let sampler = SubsetRandomSampler::new(vec![42]);
        assert_eq!(sampler.len(), 1);

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn, vec![42]);
    }

    /// Two samplers with the same seed produce the same order.
    #[test]
    fn test_subset_random_sampler_reproducible() {
        let subset_indices = vec![10, 20, 30, 40, 50];
        let first = SubsetRandomSampler::new(subset_indices.clone()).with_generator(123);
        let second = SubsetRandomSampler::new(subset_indices).with_generator(123);

        let first_order: Vec<usize> = first.iter().collect();
        let second_order: Vec<usize> = second.iter().collect();

        assert_eq!(first_order, second_order);
    }

    /// The free-function constructors forward their arguments unchanged.
    #[test]
    fn test_convenience_functions() {
        let weights = vec![1.0, 2.0, 3.0];
        let weighted = weighted_random(weights.clone(), 10, true, Some(42));
        assert_eq!(weighted.weights(), weights.as_slice());
        assert_eq!(weighted.num_samples(), 10);
        assert!(weighted.replacement());
        assert_eq!(weighted.generator(), Some(42));

        let indices = vec![1, 3, 5];
        let subset = subset_random(indices.clone(), Some(42));
        assert_eq!(subset.indices(), indices.as_slice());
        assert_eq!(subset.generator(), Some(42));

        let class_counts = vec![100, 50, 25];
        let balanced = balanced_weighted(&class_counts, 30, Some(42));
        assert_eq!(balanced.num_samples(), 30);
        assert!(balanced.replacement());

        // Rarer classes must receive larger weights.
        let class_weights = balanced.weights();
        assert!(class_weights[2] > class_weights[1]);
        assert!(class_weights[1] > class_weights[0]);
    }

    #[test]
    fn test_balanced_weighted_edge_cases() {
        // A class with no items gets zero weight; the others stay positive.
        let with_empty_class = balanced_weighted(&[100, 0, 50], 20, Some(42));
        let class_weights = with_empty_class.weights();

        assert!(class_weights[0] > 0.0);
        assert_eq!(class_weights[1], 0.0);
        assert!(class_weights[2] > 0.0);
        assert!(class_weights[2] > class_weights[0]);

        // A single class still yields one positive weight.
        let single_class = balanced_weighted(&[100], 10, Some(42));
        assert_eq!(single_class.weights().len(), 1);
        assert!(single_class.weights()[0] > 0.0);
    }

    #[test]
    fn test_weighted_sampler_clone() {
        let original =
            WeightedRandomSampler::new(vec![1.0, 2.0, 3.0], 10, true).with_generator(42);
        let copy = original.clone();

        assert_eq!(original.weights(), copy.weights());
        assert_eq!(original.num_samples(), copy.num_samples());
        assert_eq!(original.replacement(), copy.replacement());
        assert_eq!(original.generator(), copy.generator());
    }

    #[test]
    fn test_subset_sampler_clone() {
        let original = SubsetRandomSampler::new(vec![1, 3, 5, 7]).with_generator(42);
        let copy = original.clone();

        assert_eq!(original.indices(), copy.indices());
        assert_eq!(original.generator(), copy.generator());
    }
}