1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use scirs2_core::RngExt;
11
12use super::core::{rng_utils, Sampler, SamplerIterator};
13
/// Sampler that draws dataset indices with probability driven by per-item
/// importance weights.
///
/// Instances are built with [`ImportanceSampler::new`] and refined with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// One weight per item; the constructor rejects negative or NaN entries.
    importance_weights: Vec<f64>,
    /// Number of indices requested per sampling pass.
    num_samples: usize,
    /// Whether the same index may be drawn more than once in a pass.
    replacement: bool,
    /// Temperature applied to the weights before sampling; `1.0` (the value set
    /// by the constructor) uses the weights exactly as given.
    temperature: f64,
    /// Optional seed forwarded to `rng_utils::create_rng` when indices are drawn.
    generator: Option<u64>,
}
42
43impl ImportanceSampler {
44 pub fn new(importance_weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
65 assert!(
67 !importance_weights.is_empty() || num_samples == 0,
68 "importance_weights cannot be empty when num_samples > 0"
69 );
70 assert!(
71 importance_weights.iter().all(|&w| w >= 0.0),
72 "importance_weights must be non-negative"
73 );
74
75 if !importance_weights.is_empty() {
77 let weight_sum: f64 = importance_weights.iter().sum();
78 assert!(
79 weight_sum > 0.0 && weight_sum.is_finite(),
80 "importance_weights must sum to a positive finite value"
81 );
82 }
83
84 let clamped_num_samples = if !replacement {
86 num_samples.min(importance_weights.len())
87 } else {
88 num_samples
89 };
90
91 Self {
92 importance_weights,
93 num_samples: clamped_num_samples,
94 replacement,
95 temperature: 1.0,
96 generator: None,
97 }
98 }
99
100 pub fn with_temperature(mut self, temperature: f64) -> Self {
121 assert!(temperature > 0.0, "temperature must be positive");
122 self.temperature = temperature;
123 self
124 }
125
126 pub fn with_generator(mut self, seed: u64) -> Self {
132 self.generator = Some(seed);
133 self
134 }
135
136 pub fn importance_weights(&self) -> &[f64] {
138 &self.importance_weights
139 }
140
141 pub fn num_samples(&self) -> usize {
143 self.num_samples
144 }
145
146 pub fn replacement(&self) -> bool {
148 self.replacement
149 }
150
151 pub fn temperature(&self) -> f64 {
153 self.temperature
154 }
155
156 pub fn generator(&self) -> Option<u64> {
158 self.generator
159 }
160
161 pub fn update_weights(&mut self, new_weights: Vec<f64>) {
171 assert_eq!(
172 new_weights.len(),
173 self.importance_weights.len(),
174 "New weights must have same length as original weights"
175 );
176 assert!(
177 new_weights.iter().all(|&w| w >= 0.0),
178 "importance_weights must be non-negative"
179 );
180
181 let weight_sum: f64 = new_weights.iter().sum();
182 assert!(
183 weight_sum > 0.0 && weight_sum.is_finite(),
184 "importance_weights must sum to a positive finite value"
185 );
186
187 self.importance_weights = new_weights;
188 }
189
190 fn get_scaled_weights(&self) -> Vec<f64> {
192 if (self.temperature - 1.0).abs() < f64::EPSILON {
193 return self.importance_weights.clone();
194 }
195
196 let max_weight = self
198 .importance_weights
199 .iter()
200 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
201
202 self.importance_weights
203 .iter()
204 .map(|&w| ((w - max_weight) / self.temperature).exp())
205 .collect()
206 }
207
208 fn sample_with_replacement(&self) -> Vec<usize> {
210 let mut rng = rng_utils::create_rng(self.generator);
212
213 let scaled_weights = self.get_scaled_weights();
214 let weight_sum: f64 = scaled_weights.iter().sum();
215
216 let mut cumulative_weights = Vec::with_capacity(scaled_weights.len());
218 let mut cumsum = 0.0;
219
220 for &weight in &scaled_weights {
221 cumsum += weight / weight_sum;
222 cumulative_weights.push(cumsum);
223 }
224
225 if let Some(last) = cumulative_weights.last_mut() {
227 *last = 1.0;
228 }
229
230 (0..self.num_samples)
232 .map(|_| {
233 let rand_val: f64 = rng.random();
234 cumulative_weights
235 .binary_search_by(|&x| {
236 x.partial_cmp(&rand_val)
237 .unwrap_or(std::cmp::Ordering::Equal)
238 })
239 .unwrap_or_else(|i| i)
240 .min(self.importance_weights.len() - 1)
241 })
242 .collect()
243 }
244
245 fn sample_without_replacement(&self) -> Vec<usize> {
247 if self.num_samples >= self.importance_weights.len() {
248 return (0..self.importance_weights.len()).collect();
250 }
251
252 let mut rng = rng_utils::create_rng(self.generator);
254
255 let scaled_weights = self.get_scaled_weights();
256 let mut selected_indices = Vec::new();
257 let mut remaining_weights = scaled_weights;
258 let mut remaining_indices: Vec<usize> = (0..self.importance_weights.len()).collect();
259
260 for _ in 0..self.num_samples {
261 if remaining_indices.is_empty() {
262 break;
263 }
264
265 let weight_sum: f64 = remaining_weights.iter().sum();
267 if weight_sum <= 0.0 {
268 break;
269 }
270
271 let mut cumsum = 0.0;
272 let rand_val: f64 = rng.random::<f64>() * weight_sum;
273
274 let mut selected_idx = 0;
275 for (i, &weight) in remaining_weights.iter().enumerate() {
276 cumsum += weight;
277 if cumsum >= rand_val {
278 selected_idx = i;
279 break;
280 }
281 }
282
283 selected_indices.push(remaining_indices[selected_idx]);
285
286 remaining_indices.remove(selected_idx);
288 remaining_weights.remove(selected_idx);
289 }
290
291 selected_indices
292 }
293
294 pub fn sampling_stats(&self) -> ImportanceStats {
296 let scaled_weights = self.get_scaled_weights();
297 let weight_sum: f64 = scaled_weights.iter().sum();
298 let mean_weight = weight_sum / scaled_weights.len() as f64;
299
300 let variance = scaled_weights
301 .iter()
302 .map(|&w| (w - mean_weight).powi(2))
303 .sum::<f64>()
304 / scaled_weights.len() as f64;
305
306 let max_weight = scaled_weights
307 .iter()
308 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
309 let min_weight = scaled_weights.iter().fold(f64::INFINITY, |a, &b| a.min(b));
310
311 ImportanceStats {
312 num_samples: self.num_samples,
313 total_items: self.importance_weights.len(),
314 replacement: self.replacement,
315 temperature: self.temperature,
316 mean_weight,
317 weight_variance: variance,
318 weight_range: max_weight - min_weight,
319 weight_ratio: if min_weight > 0.0 {
320 max_weight / min_weight
321 } else {
322 f64::INFINITY
323 },
324 }
325 }
326}
327
328impl Sampler for ImportanceSampler {
329 type Iter = SamplerIterator;
330
331 fn iter(&self) -> Self::Iter {
332 let indices = if self.replacement {
333 self.sample_with_replacement()
334 } else {
335 self.sample_without_replacement()
336 };
337
338 SamplerIterator::new(indices)
339 }
340
341 fn len(&self) -> usize {
342 if self.replacement {
343 self.num_samples
344 } else {
345 self.num_samples.min(self.importance_weights.len())
346 }
347 }
348}
349
/// Summary of an [`ImportanceSampler`]'s configuration and of its
/// temperature-scaled weights, as produced by `ImportanceSampler::sampling_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct ImportanceStats {
    /// Number of indices the sampler draws per pass.
    pub num_samples: usize,
    /// Number of weights (items) the sampler was built with.
    pub total_items: usize,
    /// Whether sampling allows repeated indices.
    pub replacement: bool,
    /// Temperature in effect when the statistics were computed.
    pub temperature: f64,
    /// Arithmetic mean of the scaled weights.
    pub mean_weight: f64,
    /// Population variance of the scaled weights (sum of squared deviations
    /// divided by the item count).
    pub weight_variance: f64,
    /// Largest scaled weight minus the smallest.
    pub weight_range: f64,
    /// Largest scaled weight divided by the smallest; `f64::INFINITY` when the
    /// smallest scaled weight is not positive.
    pub weight_ratio: f64,
}
370
371pub fn uniform_importance_sampler(
383 dataset_size: usize,
384 num_samples: usize,
385 replacement: bool,
386 seed: Option<u64>,
387) -> ImportanceSampler {
388 let weights = vec![1.0; dataset_size];
389 let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
390 if let Some(s) = seed {
391 sampler = sampler.with_generator(s);
392 }
393 sampler
394}
395
396pub fn class_balanced_importance_sampler(
408 class_labels: &[usize],
409 num_samples: usize,
410 replacement: bool,
411 seed: Option<u64>,
412) -> ImportanceSampler {
413 let max_class = class_labels.iter().max().copied().unwrap_or(0);
415 let mut class_counts = vec![0usize; max_class + 1];
416
417 for &label in class_labels {
418 if label <= max_class {
419 class_counts[label] += 1;
420 }
421 }
422
423 let total_samples = class_labels.len() as f64;
425 let num_classes = class_counts.len() as f64;
426
427 let weights: Vec<f64> = class_labels
428 .iter()
429 .map(|&label| {
430 let class_count = class_counts[label];
431 if class_count > 0 {
432 total_samples / (num_classes * class_count as f64)
433 } else {
434 1.0
435 }
436 })
437 .collect();
438
439 let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
440 if let Some(s) = seed {
441 sampler = sampler.with_generator(s);
442 }
443 sampler
444}
445
446pub fn loss_based_importance_sampler(
459 losses: &[f64],
460 num_samples: usize,
461 replacement: bool,
462 power: f64,
463 seed: Option<u64>,
464) -> ImportanceSampler {
465 let weights: Vec<f64> = losses
466 .iter()
467 .map(|&loss| loss.max(1e-6).powf(power)) .collect();
469
470 let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
471 if let Some(s) = seed {
472 sampler = sampler.with_generator(s);
473 }
474 sampler
475}
476
477pub fn exponential_importance_sampler(
490 scores: &[f64],
491 num_samples: usize,
492 replacement: bool,
493 scale: f64,
494 seed: Option<u64>,
495) -> ImportanceSampler {
496 let weights: Vec<f64> = scores.iter().map(|&score| (score * scale).exp()).collect();
497
498 let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
499 if let Some(s) = seed {
500 sampler = sampler.with_generator(s);
501 }
502 sampler
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one sampling pass and gathers every index it yields.
    fn draw(sampler: &ImportanceSampler) -> Vec<usize> {
        sampler.iter().collect()
    }

    /// Counts how often `target` occurs in `indices`.
    fn occurrences(indices: &[usize], target: usize) -> usize {
        indices.iter().filter(|&&index| index == target).count()
    }

    #[test]
    fn test_importance_sampler_basic() {
        let weights = vec![0.1, 0.5, 1.0, 0.3, 0.8];
        let sampler = ImportanceSampler::new(weights.clone(), 3, true)
            .with_temperature(1.0)
            .with_generator(42);

        // Every builder setting must be visible through its accessor.
        assert_eq!(sampler.importance_weights(), weights.as_slice());
        assert_eq!(sampler.num_samples(), 3);
        assert!(sampler.replacement());
        assert_eq!(sampler.temperature(), 1.0);
        assert_eq!(sampler.generator(), Some(42));

        let indices = draw(&sampler);
        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&index| index < 5));
    }

    #[test]
    fn test_importance_sampler_without_replacement() {
        let sampler =
            ImportanceSampler::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3, false).with_generator(42);

        assert!(!sampler.replacement());
        assert_eq!(sampler.len(), 3);

        let indices = draw(&sampler);
        assert_eq!(indices.len(), 3);

        // No index may repeat when sampling without replacement.
        let mut unique = indices.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), 3);

        assert!(indices.iter().all(|&index| index < 5));
    }

    #[test]
    fn test_importance_sampler_temperature_scaling() {
        let weights = vec![1.0, 10.0];

        let sharp = ImportanceSampler::new(weights.clone(), 10, true)
            .with_temperature(0.1)
            .with_generator(42);
        let flat = ImportanceSampler::new(weights.clone(), 10, true)
            .with_temperature(10.0)
            .with_generator(42);

        let sharp_hits = occurrences(&draw(&sharp), 1);
        let flat_hits = occurrences(&draw(&flat), 1);

        // A low temperature must favour the heavy item at least as much as a
        // high temperature does.
        assert!(sharp_hits >= flat_hits);
    }

    #[test]
    fn test_importance_sampler_edge_cases() {
        // A single item is the only possible outcome.
        let single = ImportanceSampler::new(vec![1.0], 1, false);
        assert_eq!(draw(&single), vec![0]);

        // Requesting nothing yields nothing.
        let none_requested = ImportanceSampler::new(vec![1.0, 2.0], 0, true);
        assert_eq!(none_requested.len(), 0);
        assert!(draw(&none_requested).is_empty());

        // Without replacement the request is capped at the item count.
        let capped = ImportanceSampler::new(vec![1.0, 2.0], 5, false);
        assert_eq!(capped.len(), 2);
        assert_eq!(draw(&capped).len(), 2);
    }

    #[test]
    fn test_importance_sampler_extreme_weights() {
        let sampler =
            ImportanceSampler::new(vec![0.001, 0.001, 1000.0, 0.001], 20, true).with_generator(42);

        let indices = draw(&sampler);
        assert_eq!(indices.len(), 20);

        // The dominant item should account for most of the draws.
        assert!(occurrences(&indices, 2) > 10);
    }

    #[test]
    fn test_update_weights() {
        let mut sampler = ImportanceSampler::new(vec![1.0, 2.0, 3.0], 2, true);

        let replacement_weights = vec![3.0, 1.0, 2.0];
        sampler.update_weights(replacement_weights.clone());

        assert_eq!(sampler.importance_weights(), replacement_weights.as_slice());
    }

    #[test]
    fn test_sampling_stats() {
        let stats = ImportanceSampler::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3, true).sampling_stats();

        assert_eq!(stats.num_samples, 3);
        assert_eq!(stats.total_items, 5);
        assert!(stats.replacement);
        assert_eq!(stats.temperature, 1.0);
        assert!(stats.mean_weight > 0.0);
        assert!(stats.weight_variance >= 0.0);
        assert!(stats.weight_range >= 0.0);
        assert!(stats.weight_ratio >= 1.0);
    }

    #[test]
    fn test_convenience_functions() {
        // Uniform: one unit weight per item.
        let uniform = uniform_importance_sampler(10, 5, true, Some(42));
        assert_eq!(uniform.importance_weights().len(), 10);
        assert!(uniform.importance_weights().iter().all(|&w| w == 1.0));

        // Class balanced: the rarer the class, the larger the weight.
        let labels = vec![0, 0, 0, 1, 1, 2];
        let balanced = class_balanced_importance_sampler(&labels, 4, true, Some(42));
        let balanced_weights = balanced.importance_weights();
        assert!(balanced_weights[5] > balanced_weights[3]);
        assert!(balanced_weights[3] > balanced_weights[0]);

        // Loss based: larger losses map to larger weights.
        let losses = vec![0.1, 0.8, 0.3, 0.9, 0.2];
        let by_loss = loss_based_importance_sampler(&losses, 3, true, 1.0, Some(42));
        let loss_weights = by_loss.importance_weights();
        assert!(loss_weights[3] > loss_weights[2]);
        assert!(loss_weights[1] > loss_weights[0]);

        // Exponential: weights grow with the score.
        let scores = vec![1.0, 2.0, 3.0];
        let by_score = exponential_importance_sampler(&scores, 2, true, 1.0, Some(42));
        let score_weights = by_score.importance_weights();
        assert!(score_weights[2] > score_weights[1]);
        assert!(score_weights[1] > score_weights[0]);
    }

    #[test]
    fn test_scaled_weights() {
        let weights = vec![1.0, 2.0, 3.0];
        let base = ImportanceSampler::new(weights.clone(), 2, true);

        // The default temperature leaves the weights untouched.
        assert_eq!(base.get_scaled_weights(), weights);

        // Any other temperature must change them.
        let cooled = base.clone().with_temperature(0.5).get_scaled_weights();
        let heated = base.clone().with_temperature(2.0).get_scaled_weights();
        assert_ne!(cooled, weights);
        assert_ne!(heated, weights);
    }

    #[test]
    fn test_reproducibility() {
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let first = ImportanceSampler::new(weights.clone(), 3, true).with_generator(123);
        let second = ImportanceSampler::new(weights, 3, true).with_generator(123);

        // Identical seeds must give identical draws.
        assert_eq!(draw(&first), draw(&second));
    }

    #[test]
    #[should_panic(expected = "importance_weights cannot be empty")]
    fn test_empty_weights() {
        let _ = ImportanceSampler::new(vec![], 5, true);
    }

    #[test]
    #[should_panic(expected = "importance_weights must be non-negative")]
    fn test_negative_weights() {
        let _ = ImportanceSampler::new(vec![1.0, -1.0, 2.0], 3, true);
    }

    #[test]
    #[should_panic(expected = "importance_weights must sum to a positive finite value")]
    fn test_zero_sum_weights() {
        let _ = ImportanceSampler::new(vec![0.0, 0.0, 0.0], 2, true);
    }

    #[test]
    fn test_invalid_no_replacement() {
        // Over-asking without replacement is clamped instead of rejected.
        let clamped = ImportanceSampler::new(vec![1.0, 2.0], 5, false);
        assert_eq!(clamped.len(), 2);
    }

    #[test]
    #[should_panic(expected = "temperature must be positive")]
    fn test_invalid_temperature() {
        let _ = ImportanceSampler::new(vec![1.0, 2.0], 1, true).with_temperature(0.0);
    }

    #[test]
    #[should_panic(expected = "New weights must have same length")]
    fn test_update_weights_wrong_size() {
        let mut sampler = ImportanceSampler::new(vec![1.0, 2.0, 3.0], 2, true);
        sampler.update_weights(vec![1.0, 2.0]);
    }

    #[test]
    fn test_class_balanced_edge_cases() {
        // No labels, no weights.
        let empty = class_balanced_importance_sampler(&[], 0, true, None);
        assert!(empty.importance_weights().is_empty());

        // A single class gives every item the same positive weight.
        let one_class = vec![0, 0, 0];
        let single = class_balanced_importance_sampler(&one_class, 2, true, None);
        let single_weights = single.importance_weights();
        assert!(single_weights.iter().all(|&w| w > 0.0));
        assert!((single_weights[0] - single_weights[1]).abs() < f64::EPSILON);

        // Widely spaced labels still produce one weight per item.
        let spaced_labels = vec![0, 100, 5];
        let spaced = class_balanced_importance_sampler(&spaced_labels, 2, true, None);
        assert_eq!(spaced.importance_weights().len(), 3);
    }

    #[test]
    fn test_loss_based_edge_cases() {
        // Zero losses are floored, so the weights stay positive.
        let all_zero = vec![0.0, 0.0, 0.0];
        let floored = loss_based_importance_sampler(&all_zero, 2, true, 1.0, None);
        assert!(floored.importance_weights().iter().all(|&w| w > 0.0));

        // The ordering survives very different magnitudes.
        let spread = vec![1e-10, 1e10];
        let extreme = loss_based_importance_sampler(&spread, 1, true, 1.0, None);
        let extreme_weights = extreme.importance_weights();
        assert!(extreme_weights[1] > extreme_weights[0]);
    }

    #[test]
    fn test_exponential_edge_cases() {
        // Negative scores still give positive, ordered weights.
        let mixed_scores = vec![-1.0, 0.0, 1.0];
        let mixed = exponential_importance_sampler(&mixed_scores, 2, true, 1.0, None);
        let mixed_weights = mixed.importance_weights();
        assert!(mixed_weights.iter().all(|&w| w > 0.0));
        assert!(mixed_weights[2] > mixed_weights[1]);
        assert!(mixed_weights[1] > mixed_weights[0]);

        // A large scale keeps the ordering.
        let scores = vec![1.0, 2.0];
        let scaled = exponential_importance_sampler(&scores, 1, true, 10.0, None);
        let scaled_weights = scaled.importance_weights();
        assert!(scaled_weights[1] > scaled_weights[0]);
    }
}