Skip to main content

torsh_data/sampler/
weighted.rs

1//! Weighted and subset sampling functionality
2//!
3//! This module provides sampling strategies that work with weighted datasets
4//! and subset selections, useful for imbalanced datasets and custom data selection.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
10use scirs2_core::RngExt;
11
12use super::core::{rng_utils, Sampler, SamplerIterator};
13
14/// Weighted random sampler for imbalanced datasets
15///
16/// This sampler allows sampling indices according to specified weights,
17/// which is particularly useful for handling imbalanced datasets or
18/// implementing custom sampling strategies.
19///
20/// # Examples
21///
22/// ```rust,ignore
23/// use torsh_data::sampler::{WeightedRandomSampler, Sampler};
24///
25/// // Higher weight for the last class (imbalanced dataset)
26/// let weights = vec![0.1, 0.1, 0.1, 0.1, 0.6];
27/// let sampler = WeightedRandomSampler::new(weights, 100, true).with_generator(42);
28///
29/// let indices: Vec<usize> = sampler.iter().collect();
30/// assert_eq!(indices.len(), 100);
31/// ```
32#[derive(Debug, Clone)]
33pub struct WeightedRandomSampler {
34    weights: Vec<f64>,
35    num_samples: usize,
36    replacement: bool,
37    generator: Option<u64>,
38}
39
40impl WeightedRandomSampler {
41    /// Create a new weighted random sampler
42    ///
43    /// # Arguments
44    ///
45    /// * `weights` - Sampling weights for each index
46    /// * `num_samples` - Number of samples to generate
47    /// * `replacement` - Whether to sample with replacement
48    ///
49    /// # Panics
50    ///
51    /// Panics if weights are empty, contain negative values, or don't sum to a positive finite value
52    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
53        assert!(!weights.is_empty(), "weights cannot be empty");
54        assert!(
55            weights.iter().all(|&w| w >= 0.0),
56            "weights must be non-negative"
57        );
58        let weight_sum: f64 = weights.iter().sum();
59        assert!(
60            weight_sum > 0.0 && weight_sum.is_finite(),
61            "weights must sum to a positive finite value, got {weight_sum}"
62        );
63
64        Self {
65            weights,
66            num_samples,
67            replacement,
68            generator: None,
69        }
70    }
71
72    /// Set random generator seed
73    ///
74    /// # Arguments
75    ///
76    /// * `seed` - Random seed for reproducible sampling
77    pub fn with_generator(mut self, seed: u64) -> Self {
78        self.generator = Some(seed);
79        self
80    }
81
82    /// Get the weights
83    pub fn weights(&self) -> &[f64] {
84        &self.weights
85    }
86
87    /// Get the number of samples
88    pub fn num_samples(&self) -> usize {
89        self.num_samples
90    }
91
92    /// Check if sampling with replacement
93    pub fn replacement(&self) -> bool {
94        self.replacement
95    }
96
97    /// Get the generator seed if set
98    pub fn generator(&self) -> Option<u64> {
99        self.generator
100    }
101
102    /// Sample indices according to weights with replacement
103    fn sample_with_replacement(&self) -> Vec<usize> {
104        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
105        let mut rng = rng_utils::create_rng(self.generator);
106
107        // Normalize weights to create cumulative distribution
108        let weight_sum: f64 = self.weights.iter().sum();
109        let mut cumulative_weights = Vec::with_capacity(self.weights.len());
110        let mut cumsum = 0.0;
111
112        for &weight in &self.weights {
113            cumsum += weight / weight_sum;
114            cumulative_weights.push(cumsum);
115        }
116
117        // Ensure the last value is exactly 1.0 to handle floating point precision
118        if let Some(last) = cumulative_weights.last_mut() {
119            *last = 1.0;
120        }
121
122        // Sample using inverse transform sampling
123        (0..self.num_samples)
124            .map(|_| {
125                let rand_val: f64 = rng.random();
126                // Binary search for the first cumulative weight >= rand_val
127                cumulative_weights
128                    .binary_search_by(|&x| {
129                        x.partial_cmp(&rand_val)
130                            .unwrap_or(std::cmp::Ordering::Equal)
131                    })
132                    .unwrap_or_else(|i| i)
133                    .min(self.weights.len() - 1)
134            })
135            .collect()
136    }
137
138    /// Sample indices according to weights without replacement
139    fn sample_without_replacement(&self) -> Vec<usize> {
140        if self.num_samples >= self.weights.len() {
141            // Return all indices if we need more samples than available
142            return (0..self.weights.len()).collect();
143        }
144
145        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
146        let mut rng = rng_utils::create_rng(self.generator);
147
148        // Use weighted reservoir sampling for sampling without replacement
149        let mut selected_indices = Vec::new();
150        let mut remaining_weights = self.weights.clone();
151        let mut remaining_indices: Vec<usize> = (0..self.weights.len()).collect();
152
153        for _ in 0..self.num_samples {
154            if remaining_indices.is_empty() {
155                break;
156            }
157
158            // Normalize remaining weights
159            let weight_sum: f64 = remaining_weights.iter().sum();
160            if weight_sum <= 0.0 {
161                break;
162            }
163
164            let mut cumsum = 0.0;
165            let rand_val: f64 = rng.random::<f64>() * weight_sum;
166
167            let mut selected_idx = 0;
168            for (i, &weight) in remaining_weights.iter().enumerate() {
169                cumsum += weight;
170                if cumsum >= rand_val {
171                    selected_idx = i;
172                    break;
173                }
174            }
175
176            // Add the selected index to results
177            selected_indices.push(remaining_indices[selected_idx]);
178
179            // Remove the selected index and weight
180            remaining_indices.remove(selected_idx);
181            remaining_weights.remove(selected_idx);
182        }
183
184        selected_indices
185    }
186}
187
188impl Sampler for WeightedRandomSampler {
189    type Iter = SamplerIterator;
190
191    fn iter(&self) -> Self::Iter {
192        let indices = if self.replacement {
193            self.sample_with_replacement()
194        } else {
195            self.sample_without_replacement()
196        };
197
198        SamplerIterator::new(indices)
199    }
200
201    fn len(&self) -> usize {
202        self.num_samples
203    }
204}
205
206/// Random sampler that samples from a subset of indices
207///
208/// This sampler takes a predefined subset of indices and samples from them
209/// in random order. Useful for creating custom data splits or working with
210/// filtered datasets.
211///
212/// # Examples
213///
214/// ```rust,ignore
215/// use torsh_data::sampler::{SubsetRandomSampler, Sampler};
216///
217/// // Sample from odd indices only
218/// let subset_indices = vec![1, 3, 5, 7, 9];
219/// let sampler = SubsetRandomSampler::new(subset_indices).with_generator(42);
220///
221/// let indices: Vec<usize> = sampler.iter().collect();
222/// assert_eq!(indices.len(), 5);
223/// ```
224#[derive(Debug, Clone)]
225pub struct SubsetRandomSampler {
226    indices: Vec<usize>,
227    generator: Option<u64>,
228}
229
230impl SubsetRandomSampler {
231    /// Create a new subset random sampler
232    ///
233    /// # Arguments
234    ///
235    /// * `indices` - The subset of indices to sample from
236    pub fn new(indices: Vec<usize>) -> Self {
237        Self {
238            indices,
239            generator: None,
240        }
241    }
242
243    /// Set random generator seed
244    ///
245    /// # Arguments
246    ///
247    /// * `seed` - Random seed for reproducible sampling
248    pub fn with_generator(mut self, seed: u64) -> Self {
249        self.generator = Some(seed);
250        self
251    }
252
253    /// Get the subset indices
254    pub fn indices(&self) -> &[usize] {
255        &self.indices
256    }
257
258    /// Get the generator seed if set
259    pub fn generator(&self) -> Option<u64> {
260        self.generator
261    }
262}
263
264impl Sampler for SubsetRandomSampler {
265    type Iter = SamplerIterator;
266
267    fn iter(&self) -> Self::Iter {
268        let mut shuffled_indices = self.indices.clone();
269        rng_utils::shuffle_indices(&mut shuffled_indices, self.generator);
270        SamplerIterator::new(shuffled_indices)
271    }
272
273    fn len(&self) -> usize {
274        self.indices.len()
275    }
276}
277
278/// Create a weighted random sampler
279///
280/// Convenience function for creating a weighted random sampler.
281///
282/// # Arguments
283///
284/// * `weights` - Sampling weights for each index
285/// * `num_samples` - Number of samples to generate
286/// * `replacement` - Whether to sample with replacement
287/// * `seed` - Optional random seed for reproducible sampling
288pub fn weighted_random(
289    weights: Vec<f64>,
290    num_samples: usize,
291    replacement: bool,
292    seed: Option<u64>,
293) -> WeightedRandomSampler {
294    let mut sampler = WeightedRandomSampler::new(weights, num_samples, replacement);
295    if let Some(s) = seed {
296        sampler = sampler.with_generator(s);
297    }
298    sampler
299}
300
301/// Create a subset random sampler
302///
303/// Convenience function for creating a subset random sampler.
304///
305/// # Arguments
306///
307/// * `indices` - The subset of indices to sample from
308/// * `seed` - Optional random seed for reproducible sampling
309pub fn subset_random(indices: Vec<usize>, seed: Option<u64>) -> SubsetRandomSampler {
310    let mut sampler = SubsetRandomSampler::new(indices);
311    if let Some(s) = seed {
312        sampler = sampler.with_generator(s);
313    }
314    sampler
315}
316
317/// Create a balanced weighted sampler for class imbalance
318///
319/// Creates weights that are inversely proportional to class frequencies,
320/// providing balanced sampling for imbalanced datasets.
321///
322/// # Arguments
323///
324/// * `class_counts` - Number of samples per class
325/// * `num_samples` - Total number of samples to generate
326/// * `seed` - Optional random seed for reproducible sampling
327pub fn balanced_weighted(
328    class_counts: &[usize],
329    num_samples: usize,
330    seed: Option<u64>,
331) -> WeightedRandomSampler {
332    let total_samples: usize = class_counts.iter().sum();
333    let num_classes = class_counts.len();
334
335    // Calculate inverse frequency weights
336    let weights: Vec<f64> = class_counts
337        .iter()
338        .map(|&count| {
339            if count > 0 {
340                total_samples as f64 / (num_classes as f64 * count as f64)
341            } else {
342                0.0
343            }
344        })
345        .collect();
346
347    weighted_random(weights, num_samples, true, seed)
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Sorted, de-duplicated copy of `values`.
    fn unique_sorted(values: &[usize]) -> Vec<usize> {
        let mut out = values.to_vec();
        out.sort_unstable();
        out.dedup();
        out
    }

    #[test]
    fn test_weighted_sampler_basic() {
        // Index 4 carries most of the probability mass.
        let weights = vec![0.1, 0.1, 0.1, 0.1, 0.6];
        let sampler = WeightedRandomSampler::new(weights.clone(), 100, true).with_generator(42);

        assert_eq!(sampler.len(), 100);
        assert_eq!(sampler.weights(), weights.as_slice());
        assert_eq!(sampler.num_samples(), 100);
        assert!(sampler.replacement());
        assert_eq!(sampler.generator(), Some(42));

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 100);
        assert!(drawn.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_weighted_sampler_without_replacement() {
        let sampler =
            WeightedRandomSampler::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3, false).with_generator(42);
        assert!(!sampler.replacement());

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 3);
        // No index may repeat.
        assert_eq!(unique_sorted(&drawn).len(), 3);
        assert!(drawn.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_weighted_sampler_uniform_weights() {
        let sampler = WeightedRandomSampler::new(vec![1.0; 10], 50, true).with_generator(42);

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 50);

        // 50 uniform draws over 10 slots should touch every slot.
        let mut counts = [0usize; 10];
        drawn.iter().for_each(|&i| counts[i] += 1);
        assert!(counts.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_weighted_sampler_extreme_weights() {
        // Only index 3 has any weight.
        let sampler =
            WeightedRandomSampler::new(vec![0.0, 0.0, 0.0, 1.0], 10, true).with_generator(42);

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 10);
        assert!(drawn.iter().all(|&i| i == 3));
    }

    #[test]
    #[should_panic(expected = "weights cannot be empty")]
    fn test_weighted_sampler_empty_weights() {
        let _ = WeightedRandomSampler::new(Vec::new(), 10, true);
    }

    #[test]
    #[should_panic(expected = "weights must be non-negative")]
    fn test_weighted_sampler_negative_weights() {
        let _ = WeightedRandomSampler::new(vec![1.0, -1.0, 1.0], 10, true);
    }

    #[test]
    #[should_panic(expected = "weights must sum to a positive finite value")]
    fn test_weighted_sampler_zero_sum() {
        let _ = WeightedRandomSampler::new(vec![0.0; 3], 10, true);
    }

    #[test]
    fn test_subset_random_sampler() {
        let subset: Vec<usize> = vec![1, 3, 5, 7, 9];
        let sampler = SubsetRandomSampler::new(subset.clone()).with_generator(42);

        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.indices(), subset.as_slice());
        assert_eq!(sampler.generator(), Some(42));

        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 5);
        assert!(drawn.iter().all(|i| subset.contains(i)));

        // Same multiset, possibly reordered.
        let mut drawn_sorted = drawn;
        drawn_sorted.sort_unstable();
        let mut expected = subset;
        expected.sort_unstable();
        assert_eq!(drawn_sorted, expected);
    }

    #[test]
    fn test_subset_random_sampler_empty() {
        let sampler = SubsetRandomSampler::new(Vec::new());
        assert_eq!(sampler.len(), 0);
        assert!(sampler.is_empty());
        assert_eq!(sampler.iter().count(), 0);
    }

    #[test]
    fn test_subset_random_sampler_single() {
        let sampler = SubsetRandomSampler::new(vec![42]);
        assert_eq!(sampler.len(), 1);
        assert_eq!(sampler.iter().collect::<Vec<usize>>(), vec![42]);
    }

    #[test]
    fn test_subset_random_sampler_reproducible() {
        let subset = vec![10, 20, 30, 40, 50];
        let first: Vec<usize> = SubsetRandomSampler::new(subset.clone())
            .with_generator(123)
            .iter()
            .collect();
        let second: Vec<usize> = SubsetRandomSampler::new(subset)
            .with_generator(123)
            .iter()
            .collect();

        assert_eq!(first, second);
    }

    #[test]
    fn test_convenience_functions() {
        // weighted_random
        let weights = vec![1.0, 2.0, 3.0];
        let weighted = weighted_random(weights.clone(), 10, true, Some(42));
        assert_eq!(weighted.weights(), weights.as_slice());
        assert_eq!(weighted.num_samples(), 10);
        assert!(weighted.replacement());
        assert_eq!(weighted.generator(), Some(42));

        // subset_random
        let indices: Vec<usize> = vec![1, 3, 5];
        let subset = subset_random(indices.clone(), Some(42));
        assert_eq!(subset.indices(), indices.as_slice());
        assert_eq!(subset.generator(), Some(42));

        // balanced_weighted on imbalanced classes
        let balanced = balanced_weighted(&[100, 50, 25], 30, Some(42));
        assert_eq!(balanced.num_samples(), 30);
        assert!(balanced.replacement());

        // Weights grow as class counts shrink.
        let w = balanced.weights();
        assert!(w[2] > w[1]);
        assert!(w[1] > w[0]);
    }

    #[test]
    fn test_balanced_weighted_edge_cases() {
        // A class with zero samples gets zero weight.
        let balanced = balanced_weighted(&[100, 0, 50], 20, Some(42));
        let w = balanced.weights();
        assert!(w[0] > 0.0);
        assert_eq!(w[1], 0.0);
        assert!(w[2] > 0.0);
        assert!(w[2] > w[0]);

        // A single class still yields one positive weight.
        let single = balanced_weighted(&[100], 10, Some(42));
        assert_eq!(single.weights().len(), 1);
        assert!(single.weights()[0] > 0.0);
    }

    #[test]
    fn test_weighted_sampler_clone() {
        let original = WeightedRandomSampler::new(vec![1.0, 2.0, 3.0], 10, true).with_generator(42);
        let copy = original.clone();

        assert_eq!(original.weights(), copy.weights());
        assert_eq!(original.num_samples(), copy.num_samples());
        assert_eq!(original.replacement(), copy.replacement());
        assert_eq!(original.generator(), copy.generator());
    }

    #[test]
    fn test_subset_sampler_clone() {
        let original = SubsetRandomSampler::new(vec![1, 3, 5, 7]).with_generator(42);
        let copy = original.clone();

        assert_eq!(original.indices(), copy.indices());
        assert_eq!(original.generator(), copy.generator());
    }
}