Skip to main content

torsh_data/sampler/
importance.rs

1//! Importance sampling functionality
2//!
3//! This module provides importance sampling strategies that adjust sampling probabilities
4//! to correct for dataset bias or to emphasize certain types of samples during training.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
10use scirs2_core::RngExt;
11
12use super::core::{rng_utils, Sampler, SamplerIterator};
13
/// Sampler that draws dataset indices according to per-item importance weights.
///
/// Items with a larger weight are drawn more often than items with a smaller
/// one. This is useful for correcting dataset bias or for emphasising
/// particular samples during training. A temperature can be applied to
/// sharpen or flatten the weight distribution, and a seed can be supplied for
/// reproducible draws.
///
/// # Examples
///
/// ```rust,ignore
/// use torsh_data::sampler::{ImportanceSampler, Sampler};
///
/// // One weight per dataset item (higher = more important)
/// let importance_weights = vec![0.1, 0.5, 1.0, 0.3, 0.8];
/// let sampler = ImportanceSampler::new(importance_weights, 3, true)
///     .with_temperature(0.5)
///     .with_generator(42);
///
/// let indices: Vec<usize> = sampler.iter().collect();
/// assert_eq!(indices.len(), 3);
/// ```
#[derive(Clone, Debug)]
pub struct ImportanceSampler {
    // One non-negative weight per dataset item; position == dataset index.
    importance_weights: Vec<f64>,
    // How many indices a single pass should yield.
    num_samples: usize,
    // `true` if an index may be drawn more than once per pass.
    replacement: bool,
    // Temperature applied to the weights before sampling (1.0 = unchanged).
    temperature: f64,
    // Optional seed handed to the random number generator.
    generator: Option<u64>,
}
42
43impl ImportanceSampler {
44    /// Create a new importance sampler
45    ///
46    /// # Arguments
47    ///
48    /// * `importance_weights` - Importance weights for each sample (higher = more important)
49    /// * `num_samples` - Number of samples to select
50    /// * `replacement` - Whether to sample with replacement
51    ///
52    /// # Panics
53    ///
54    /// Panics if importance_weights is empty or contains negative values.
55    ///
56    /// # Examples
57    ///
58    /// ```rust,ignore
59    /// use torsh_data::sampler::ImportanceSampler;
60    ///
61    /// let weights = vec![1.0, 2.0, 0.5, 3.0]; // Index 3 is most important
62    /// let sampler = ImportanceSampler::new(weights, 2, false);
63    /// ```
64    pub fn new(importance_weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
65        // Validate importance weights
66        assert!(
67            !importance_weights.is_empty() || num_samples == 0,
68            "importance_weights cannot be empty when num_samples > 0"
69        );
70        assert!(
71            importance_weights.iter().all(|&w| w >= 0.0),
72            "importance_weights must be non-negative"
73        );
74
75        // Skip weight sum validation for empty weights (when num_samples is 0)
76        if !importance_weights.is_empty() {
77            let weight_sum: f64 = importance_weights.iter().sum();
78            assert!(
79                weight_sum > 0.0 && weight_sum.is_finite(),
80                "importance_weights must sum to a positive finite value"
81            );
82        }
83
84        // Clamp num_samples to maximum available when sampling without replacement
85        let clamped_num_samples = if !replacement {
86            num_samples.min(importance_weights.len())
87        } else {
88            num_samples
89        };
90
91        Self {
92            importance_weights,
93            num_samples: clamped_num_samples,
94            replacement,
95            temperature: 1.0,
96            generator: None,
97        }
98    }
99
100    /// Set temperature for softmax scaling of importance weights
101    ///
102    /// Temperature controls the sharpness of the importance distribution:
103    /// - Higher temperature (> 1.0) = more uniform sampling
104    /// - Lower temperature (< 1.0) = more biased toward high importance samples
105    /// - Temperature = 1.0 = no scaling
106    ///
107    /// # Arguments
108    ///
109    /// * `temperature` - Temperature value (must be positive)
110    ///
111    /// # Examples
112    ///
113    /// ```rust,ignore
114    /// use torsh_data::sampler::ImportanceSampler;
115    ///
116    /// let weights = vec![1.0, 2.0, 3.0];
117    /// let sampler = ImportanceSampler::new(weights, 2, true)
118    ///     .with_temperature(0.5); // More emphasis on high importance
119    /// ```
120    pub fn with_temperature(mut self, temperature: f64) -> Self {
121        assert!(temperature > 0.0, "temperature must be positive");
122        self.temperature = temperature;
123        self
124    }
125
126    /// Set random generator seed
127    ///
128    /// # Arguments
129    ///
130    /// * `seed` - Random seed for reproducible sampling
131    pub fn with_generator(mut self, seed: u64) -> Self {
132        self.generator = Some(seed);
133        self
134    }
135
136    /// Get the importance weights
137    pub fn importance_weights(&self) -> &[f64] {
138        &self.importance_weights
139    }
140
141    /// Get the number of samples
142    pub fn num_samples(&self) -> usize {
143        self.num_samples
144    }
145
146    /// Check if sampling with replacement
147    pub fn replacement(&self) -> bool {
148        self.replacement
149    }
150
151    /// Get the temperature
152    pub fn temperature(&self) -> f64 {
153        self.temperature
154    }
155
156    /// Get the generator seed if set
157    pub fn generator(&self) -> Option<u64> {
158        self.generator
159    }
160
161    /// Update importance weights
162    ///
163    /// # Arguments
164    ///
165    /// * `new_weights` - New importance weights
166    ///
167    /// # Panics
168    ///
169    /// Panics if new_weights has different length than original weights.
170    pub fn update_weights(&mut self, new_weights: Vec<f64>) {
171        assert_eq!(
172            new_weights.len(),
173            self.importance_weights.len(),
174            "New weights must have same length as original weights"
175        );
176        assert!(
177            new_weights.iter().all(|&w| w >= 0.0),
178            "importance_weights must be non-negative"
179        );
180
181        let weight_sum: f64 = new_weights.iter().sum();
182        assert!(
183            weight_sum > 0.0 && weight_sum.is_finite(),
184            "importance_weights must sum to a positive finite value"
185        );
186
187        self.importance_weights = new_weights;
188    }
189
190    /// Apply temperature scaling to importance weights
191    fn get_scaled_weights(&self) -> Vec<f64> {
192        if (self.temperature - 1.0).abs() < f64::EPSILON {
193            return self.importance_weights.clone();
194        }
195
196        // Apply temperature scaling (like softmax)
197        let max_weight = self
198            .importance_weights
199            .iter()
200            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
201
202        self.importance_weights
203            .iter()
204            .map(|&w| ((w - max_weight) / self.temperature).exp())
205            .collect()
206    }
207
208    /// Sample indices with importance weighting and replacement
209    fn sample_with_replacement(&self) -> Vec<usize> {
210        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
211        let mut rng = rng_utils::create_rng(self.generator);
212
213        let scaled_weights = self.get_scaled_weights();
214        let weight_sum: f64 = scaled_weights.iter().sum();
215
216        // Create cumulative distribution
217        let mut cumulative_weights = Vec::with_capacity(scaled_weights.len());
218        let mut cumsum = 0.0;
219
220        for &weight in &scaled_weights {
221            cumsum += weight / weight_sum;
222            cumulative_weights.push(cumsum);
223        }
224
225        // Ensure the last value is exactly 1.0
226        if let Some(last) = cumulative_weights.last_mut() {
227            *last = 1.0;
228        }
229
230        // Sample using inverse transform sampling
231        (0..self.num_samples)
232            .map(|_| {
233                let rand_val: f64 = rng.random();
234                cumulative_weights
235                    .binary_search_by(|&x| {
236                        x.partial_cmp(&rand_val)
237                            .unwrap_or(std::cmp::Ordering::Equal)
238                    })
239                    .unwrap_or_else(|i| i)
240                    .min(self.importance_weights.len() - 1)
241            })
242            .collect()
243    }
244
245    /// Sample indices with importance weighting without replacement
246    fn sample_without_replacement(&self) -> Vec<usize> {
247        if self.num_samples >= self.importance_weights.len() {
248            // Return all indices if we need more samples than available
249            return (0..self.importance_weights.len()).collect();
250        }
251
252        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
253        let mut rng = rng_utils::create_rng(self.generator);
254
255        let scaled_weights = self.get_scaled_weights();
256        let mut selected_indices = Vec::new();
257        let mut remaining_weights = scaled_weights;
258        let mut remaining_indices: Vec<usize> = (0..self.importance_weights.len()).collect();
259
260        for _ in 0..self.num_samples {
261            if remaining_indices.is_empty() {
262                break;
263            }
264
265            // Normalize remaining weights
266            let weight_sum: f64 = remaining_weights.iter().sum();
267            if weight_sum <= 0.0 {
268                break;
269            }
270
271            let mut cumsum = 0.0;
272            let rand_val: f64 = rng.random::<f64>() * weight_sum;
273
274            let mut selected_idx = 0;
275            for (i, &weight) in remaining_weights.iter().enumerate() {
276                cumsum += weight;
277                if cumsum >= rand_val {
278                    selected_idx = i;
279                    break;
280                }
281            }
282
283            // Add the selected index to results
284            selected_indices.push(remaining_indices[selected_idx]);
285
286            // Remove the selected index and weight
287            remaining_indices.remove(selected_idx);
288            remaining_weights.remove(selected_idx);
289        }
290
291        selected_indices
292    }
293
294    /// Get sampling statistics
295    pub fn sampling_stats(&self) -> ImportanceStats {
296        let scaled_weights = self.get_scaled_weights();
297        let weight_sum: f64 = scaled_weights.iter().sum();
298        let mean_weight = weight_sum / scaled_weights.len() as f64;
299
300        let variance = scaled_weights
301            .iter()
302            .map(|&w| (w - mean_weight).powi(2))
303            .sum::<f64>()
304            / scaled_weights.len() as f64;
305
306        let max_weight = scaled_weights
307            .iter()
308            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
309        let min_weight = scaled_weights.iter().fold(f64::INFINITY, |a, &b| a.min(b));
310
311        ImportanceStats {
312            num_samples: self.num_samples,
313            total_items: self.importance_weights.len(),
314            replacement: self.replacement,
315            temperature: self.temperature,
316            mean_weight,
317            weight_variance: variance,
318            weight_range: max_weight - min_weight,
319            weight_ratio: if min_weight > 0.0 {
320                max_weight / min_weight
321            } else {
322                f64::INFINITY
323            },
324        }
325    }
326}
327
328impl Sampler for ImportanceSampler {
329    type Iter = SamplerIterator;
330
331    fn iter(&self) -> Self::Iter {
332        let indices = if self.replacement {
333            self.sample_with_replacement()
334        } else {
335            self.sample_without_replacement()
336        };
337
338        SamplerIterator::new(indices)
339    }
340
341    fn len(&self) -> usize {
342        if self.replacement {
343            self.num_samples
344        } else {
345            self.num_samples.min(self.importance_weights.len())
346        }
347    }
348}
349
/// Summary of an importance sampler's configuration and weight distribution.
#[derive(Clone, Debug, PartialEq)]
pub struct ImportanceStats {
    /// How many samples one pass draws
    pub num_samples: usize,
    /// How many items (weights) the sampler covers
    pub total_items: usize,
    /// `true` when sampling is done with replacement
    pub replacement: bool,
    /// Temperature applied to the weights
    pub temperature: f64,
    /// Average of the scaled importance weights
    pub mean_weight: f64,
    /// Variance of the importance weights
    pub weight_variance: f64,
    /// Spread of the importance weights (largest minus smallest)
    pub weight_range: f64,
    /// Largest importance weight divided by the smallest
    pub weight_ratio: f64,
}
370
371/// Create an importance sampler with uniform weights
372///
373/// Convenience function for creating an importance sampler with equal weights
374/// for all samples (equivalent to uniform sampling).
375///
376/// # Arguments
377///
378/// * `dataset_size` - Size of the dataset
379/// * `num_samples` - Number of samples to select
380/// * `replacement` - Whether to sample with replacement
381/// * `seed` - Optional random seed for reproducible sampling
382pub fn uniform_importance_sampler(
383    dataset_size: usize,
384    num_samples: usize,
385    replacement: bool,
386    seed: Option<u64>,
387) -> ImportanceSampler {
388    let weights = vec![1.0; dataset_size];
389    let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
390    if let Some(s) = seed {
391        sampler = sampler.with_generator(s);
392    }
393    sampler
394}
395
396/// Create an importance sampler from class frequencies for class balancing
397///
398/// Creates importance weights that are inversely proportional to class frequencies,
399/// helping to balance training for imbalanced datasets.
400///
401/// # Arguments
402///
403/// * `class_labels` - Class label for each sample
404/// * `num_samples` - Number of samples to select
405/// * `replacement` - Whether to sample with replacement
406/// * `seed` - Optional random seed for reproducible sampling
407pub fn class_balanced_importance_sampler(
408    class_labels: &[usize],
409    num_samples: usize,
410    replacement: bool,
411    seed: Option<u64>,
412) -> ImportanceSampler {
413    // Count frequency of each class
414    let max_class = class_labels.iter().max().copied().unwrap_or(0);
415    let mut class_counts = vec![0usize; max_class + 1];
416
417    for &label in class_labels {
418        if label <= max_class {
419            class_counts[label] += 1;
420        }
421    }
422
423    // Calculate inverse frequency weights
424    let total_samples = class_labels.len() as f64;
425    let num_classes = class_counts.len() as f64;
426
427    let weights: Vec<f64> = class_labels
428        .iter()
429        .map(|&label| {
430            let class_count = class_counts[label];
431            if class_count > 0 {
432                total_samples / (num_classes * class_count as f64)
433            } else {
434                1.0
435            }
436        })
437        .collect();
438
439    let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
440    if let Some(s) = seed {
441        sampler = sampler.with_generator(s);
442    }
443    sampler
444}
445
446/// Create an importance sampler based on loss values
447///
448/// Creates importance weights based on training losses, emphasizing
449/// samples with higher losses (harder samples).
450///
451/// # Arguments
452///
453/// * `losses` - Loss value for each sample
454/// * `num_samples` - Number of samples to select
455/// * `replacement` - Whether to sample with replacement
456/// * `power` - Power to raise losses to (higher = more emphasis on hard samples)
457/// * `seed` - Optional random seed for reproducible sampling
458pub fn loss_based_importance_sampler(
459    losses: &[f64],
460    num_samples: usize,
461    replacement: bool,
462    power: f64,
463    seed: Option<u64>,
464) -> ImportanceSampler {
465    let weights: Vec<f64> = losses
466        .iter()
467        .map(|&loss| loss.max(1e-6).powf(power)) // Avoid zero weights
468        .collect();
469
470    let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
471    if let Some(s) = seed {
472        sampler = sampler.with_generator(s);
473    }
474    sampler
475}
476
477/// Create an importance sampler with exponential weighting
478///
479/// Creates importance weights using exponential scaling, which can provide
480/// smoother transitions between importance levels.
481///
482/// # Arguments
483///
484/// * `scores` - Importance scores for each sample
485/// * `num_samples` - Number of samples to select
486/// * `replacement` - Whether to sample with replacement
487/// * `scale` - Exponential scaling factor
488/// * `seed` - Optional random seed for reproducible sampling
489pub fn exponential_importance_sampler(
490    scores: &[f64],
491    num_samples: usize,
492    replacement: bool,
493    scale: f64,
494    seed: Option<u64>,
495) -> ImportanceSampler {
496    let weights: Vec<f64> = scores.iter().map(|&score| (score * scale).exp()).collect();
497
498    let mut sampler = ImportanceSampler::new(weights, num_samples, replacement);
499    if let Some(s) = seed {
500        sampler = sampler.with_generator(s);
501    }
502    sampler
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Run one pass of `sampler` and gather the produced indices.
    fn draw(sampler: &ImportanceSampler) -> Vec<usize> {
        sampler.iter().collect()
    }

    /// Count how often `target` appears in `indices`.
    fn occurrences(indices: &[usize], target: usize) -> usize {
        indices.iter().filter(|&&i| i == target).count()
    }

    #[test]
    fn test_importance_sampler_basic() {
        let weights = vec![0.1, 0.5, 1.0, 0.3, 0.8];
        let sampler = ImportanceSampler::new(weights.clone(), 3, true)
            .with_temperature(1.0)
            .with_generator(42);

        // Configuration is reported back unchanged.
        assert_eq!(sampler.importance_weights(), weights.as_slice());
        assert_eq!(sampler.num_samples(), 3);
        assert!(sampler.replacement());
        assert_eq!(sampler.temperature(), 1.0);
        assert_eq!(sampler.generator(), Some(42));

        let drawn = draw(&sampler);
        assert_eq!(drawn.len(), 3);

        // Every index must point into the weight list.
        for idx in drawn {
            assert!(idx < 5);
        }
    }

    #[test]
    fn test_importance_sampler_without_replacement() {
        let sampler =
            ImportanceSampler::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3, false).with_generator(42);

        assert!(!sampler.replacement());
        assert_eq!(sampler.len(), 3);

        let drawn = draw(&sampler);
        assert_eq!(drawn.len(), 3);

        // No index may repeat.
        let mut distinct = drawn.clone();
        distinct.sort();
        distinct.dedup();
        assert_eq!(distinct.len(), 3);

        // Every index must point into the weight list.
        for idx in &drawn {
            assert!(*idx < 5);
        }
    }

    #[test]
    fn test_importance_sampler_temperature_scaling() {
        // Two items with very different importance.
        let weights = vec![1.0, 10.0];

        // A low temperature exaggerates the difference ...
        let sharp = ImportanceSampler::new(weights.clone(), 10, true)
            .with_temperature(0.1)
            .with_generator(42);

        // ... while a high temperature flattens it.
        let flat = ImportanceSampler::new(weights.clone(), 10, true)
            .with_temperature(10.0)
            .with_generator(42);

        // How often was the heavier item (index 1) drawn?
        let sharp_hits = occurrences(&draw(&sharp), 1);
        let flat_hits = occurrences(&draw(&flat), 1);

        // The sharper distribution must favor it at least as much.
        assert!(sharp_hits >= flat_hits);
    }

    #[test]
    fn test_importance_sampler_edge_cases() {
        // A single item, drawn once.
        let single = ImportanceSampler::new(vec![1.0], 1, false);
        assert_eq!(draw(&single), vec![0]);

        // Nothing requested, nothing returned.
        let none_requested = ImportanceSampler::new(vec![1.0, 2.0], 0, true);
        assert_eq!(none_requested.len(), 0);
        assert!(draw(&none_requested).is_empty());

        // Asking for more than exists (without replacement) is clamped.
        let over_requested = ImportanceSampler::new(vec![1.0, 2.0], 5, false);
        assert_eq!(over_requested.len(), 2);
        assert_eq!(draw(&over_requested).len(), 2);
    }

    #[test]
    fn test_importance_sampler_extreme_weights() {
        // Index 2 dwarfs all other weights.
        let sampler =
            ImportanceSampler::new(vec![0.001, 0.001, 1000.0, 0.001], 20, true).with_generator(42);

        let drawn = draw(&sampler);
        assert_eq!(drawn.len(), 20);

        // The dominant item should make up most of the draws.
        let dominant_hits = occurrences(&drawn, 2);
        assert!(dominant_hits > 10);
    }

    #[test]
    fn test_update_weights() {
        let mut sampler = ImportanceSampler::new(vec![1.0, 2.0, 3.0], 2, true);

        let replacement_weights = vec![3.0, 1.0, 2.0];
        sampler.update_weights(replacement_weights.clone());

        assert_eq!(sampler.importance_weights(), replacement_weights.as_slice());
    }

    #[test]
    fn test_sampling_stats() {
        let sampler = ImportanceSampler::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3, true);
        let stats = sampler.sampling_stats();

        assert_eq!(stats.num_samples, 3);
        assert_eq!(stats.total_items, 5);
        assert!(stats.replacement);
        assert_eq!(stats.temperature, 1.0);
        assert!(stats.mean_weight > 0.0);
        assert!(stats.weight_variance >= 0.0);
        assert!(stats.weight_range >= 0.0);
        assert!(stats.weight_ratio >= 1.0);
    }

    #[test]
    fn test_convenience_functions() {
        // uniform_importance_sampler: all weights equal to one.
        let uniform = uniform_importance_sampler(10, 5, true, Some(42));
        assert_eq!(uniform.importance_weights().len(), 10);
        assert!(uniform.importance_weights().iter().all(|&w| w == 1.0));

        // class_balanced_importance_sampler: class sizes are 3, 2 and 1.
        let labels = vec![0, 0, 0, 1, 1, 2];
        let balanced = class_balanced_importance_sampler(&labels, 4, true, Some(42));
        let balanced_weights = balanced.importance_weights();

        // The rarer the class, the larger the weight.
        assert!(balanced_weights[5] > balanced_weights[3]); // Class 2 > Class 1
        assert!(balanced_weights[3] > balanced_weights[0]); // Class 1 > Class 0

        // loss_based_importance_sampler
        let losses = vec![0.1, 0.8, 0.3, 0.9, 0.2];
        let by_loss = loss_based_importance_sampler(&losses, 3, true, 1.0, Some(42));
        let loss_weights = by_loss.importance_weights();

        // A larger loss must translate into a larger weight.
        assert!(loss_weights[3] > loss_weights[2]); // Loss 0.9 > 0.3
        assert!(loss_weights[1] > loss_weights[0]); // Loss 0.8 > 0.1

        // exponential_importance_sampler
        let scores = vec![1.0, 2.0, 3.0];
        let by_score = exponential_importance_sampler(&scores, 2, true, 1.0, Some(42));
        let score_weights = by_score.importance_weights();

        // Weights grow with the score.
        assert!(score_weights[2] > score_weights[1]);
        assert!(score_weights[1] > score_weights[0]);
    }

    #[test]
    fn test_scaled_weights() {
        let weights = vec![1.0, 2.0, 3.0];
        let sampler = ImportanceSampler::new(weights.clone(), 2, true);

        // At temperature 1.0 the weights pass through untouched.
        assert_eq!(sampler.get_scaled_weights(), weights);

        // A lower temperature widens the differences ...
        let sharpened = sampler.clone().with_temperature(0.5).get_scaled_weights();

        // ... and a higher temperature narrows them.
        let flattened = sampler.clone().with_temperature(2.0).get_scaled_weights();

        // Either way the weights must differ from the unscaled ones.
        assert_ne!(sharpened, weights);
        assert_ne!(flattened, weights);
    }

    #[test]
    fn test_reproducibility() {
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let first = ImportanceSampler::new(weights.clone(), 3, true).with_generator(123);
        let second = ImportanceSampler::new(weights, 3, true).with_generator(123);

        // Identical seeds must give identical draws.
        assert_eq!(draw(&first), draw(&second));
    }

    #[test]
    #[should_panic(expected = "importance_weights cannot be empty")]
    fn test_empty_weights() {
        ImportanceSampler::new(Vec::new(), 5, true);
    }

    #[test]
    #[should_panic(expected = "importance_weights must be non-negative")]
    fn test_negative_weights() {
        ImportanceSampler::new(vec![1.0, -1.0, 2.0], 3, true);
    }

    #[test]
    #[should_panic(expected = "importance_weights must sum to a positive finite value")]
    fn test_zero_sum_weights() {
        ImportanceSampler::new(vec![0.0, 0.0, 0.0], 2, true);
    }

    #[test]
    fn test_invalid_no_replacement() {
        // Over-requesting without replacement is clamped rather than rejected.
        let sampler = ImportanceSampler::new(vec![1.0, 2.0], 5, false);
        assert_eq!(sampler.len(), 2);
    }

    #[test]
    #[should_panic(expected = "temperature must be positive")]
    fn test_invalid_temperature() {
        ImportanceSampler::new(vec![1.0, 2.0], 1, true).with_temperature(0.0);
    }

    #[test]
    #[should_panic(expected = "New weights must have same length")]
    fn test_update_weights_wrong_size() {
        let mut sampler = ImportanceSampler::new(vec![1.0, 2.0, 3.0], 2, true);
        // Two weights for a three-item sampler.
        sampler.update_weights(vec![1.0, 2.0]);
    }

    #[test]
    fn test_class_balanced_edge_cases() {
        // No labels at all.
        let from_empty = class_balanced_importance_sampler(&[], 0, true, None);
        assert!(from_empty.importance_weights().is_empty());

        // Only one class present.
        let one_class = vec![0, 0, 0];
        let from_one_class = class_balanced_importance_sampler(&one_class, 2, true, None);
        let one_class_weights = from_one_class.importance_weights();
        assert!(one_class_weights.iter().all(|&w| w > 0.0));
        // With a single class every sample gets the same weight.
        assert!((one_class_weights[0] - one_class_weights[1]).abs() < f64::EPSILON);

        // Sparse, large label values.
        let sparse_labels = vec![0, 100, 5];
        let from_sparse = class_balanced_importance_sampler(&sparse_labels, 2, true, None);
        assert_eq!(from_sparse.importance_weights().len(), 3);
    }

    #[test]
    fn test_loss_based_edge_cases() {
        // All losses are zero: the floor must keep the weights positive.
        let all_zero = vec![0.0, 0.0, 0.0];
        let from_zero = loss_based_importance_sampler(&all_zero, 2, true, 1.0, None);
        assert!(from_zero.importance_weights().iter().all(|&w| w > 0.0));

        // Losses many orders of magnitude apart.
        let far_apart = vec![1e-10, 1e10];
        let from_extreme = loss_based_importance_sampler(&far_apart, 1, true, 1.0, None);
        let extreme_weights = from_extreme.importance_weights();
        // The larger loss must still get the larger weight.
        assert!(extreme_weights[1] > extreme_weights[0]);
    }

    #[test]
    fn test_exponential_edge_cases() {
        // Scores may be negative; the weights may not.
        let mixed_scores = vec![-1.0, 0.0, 1.0];
        let from_mixed = exponential_importance_sampler(&mixed_scores, 2, true, 1.0, None);
        let mixed_weights = from_mixed.importance_weights();
        assert!(mixed_weights.iter().all(|&w| w > 0.0));
        assert!(mixed_weights[2] > mixed_weights[1]); // Higher score, higher weight
        assert!(mixed_weights[1] > mixed_weights[0]); // exp(0) > exp(-1)

        // A large scale factor must keep the ordering intact.
        let scores = vec![1.0, 2.0];
        let steep = exponential_importance_sampler(&scores, 1, true, 10.0, None);
        let steep_weights = steep.importance_weights();
        assert!(steep_weights[1] > steep_weights[0]);
    }
}