Skip to main content

torsh_data/sampler/
core.rs

1//! Core sampling traits and utilities
2//!
3//! This module provides the fundamental building blocks for all sampling strategies.
4
5#[cfg(not(feature = "std"))]
6use alloc::{boxed::Box, vec::Vec};
7
8// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
9use scirs2_core::random::Random;
10
11/// Common RNG utilities for samplers
12pub(crate) mod rng_utils {
13    use super::*;
14
15    /// Create a seeded or default RNG
16    pub fn create_rng(seed: Option<u64>) -> Random<scirs2_core::rngs::StdRng> {
17        if let Some(seed) = seed {
18            Random::seed(seed)
19        } else {
20            Random::seed(42) // Default seed for reproducible behavior
21        }
22    }
23
24    /// Generic shuffle function
25    pub fn shuffle_indices<T: Clone>(indices: &mut [T], seed: Option<u64>) {
26        let mut rng = create_rng(seed);
27
28        // Fisher-Yates shuffle
29        for i in (1..indices.len()).rev() {
30            let j = rng.gen_range(0..=i);
31            indices.swap(i, j);
32        }
33    }
34
35    /// Generate random range
36    pub fn gen_range(
37        rng: &mut Random<scirs2_core::rngs::StdRng>,
38        range: std::ops::Range<usize>,
39    ) -> usize {
40        rng.gen_range(range)
41    }
42}
43
44/// Unified trait for sampling from a dataset
45///
46/// This trait provides a consistent interface for both individual sample
47/// and batch sampling strategies.
48pub trait Sampler: Send {
49    /// Iterator type returned by the sampler
50    type Iter: Iterator<Item = usize> + Send;
51
52    /// Create an iterator over indices
53    fn iter(&self) -> Self::Iter;
54
55    /// Total number of samples that will be yielded
56    fn len(&self) -> usize;
57
58    /// Check if sampler is empty
59    fn is_empty(&self) -> bool {
60        self.len() == 0
61    }
62
63    /// Convert this sampler into a batch sampler
64    fn into_batch_sampler(
65        self,
66        batch_size: usize,
67        drop_last: bool,
68    ) -> super::batch::BatchingSampler<Self>
69    where
70        Self: Sized,
71    {
72        super::batch::BatchingSampler::new(self, batch_size, drop_last)
73    }
74
75    /// Create a distributed version of this sampler
76    fn into_distributed(
77        self,
78        num_replicas: usize,
79        rank: usize,
80    ) -> super::distributed::DistributedWrapper<Self>
81    where
82        Self: Sized,
83    {
84        super::distributed::DistributedWrapper::new(self, num_replicas, rank)
85    }
86}
87
/// A source of index batches.
///
/// Each step of the iterator yields one `Vec<usize>` holding the indices that
/// make up a batch. Implementors must be `Send`.
pub trait BatchSampler: Send {
    /// Concrete iterator produced by [`BatchSampler::iter`]; yields one batch
    /// of indices per step and must itself be `Send`.
    type Iter: Iterator<Item = Vec<usize>> + Send;

    /// Start a new pass over the batches.
    fn iter(&self) -> Self::Iter;

    /// Number of batches a full pass yields.
    fn num_batches(&self) -> usize;

    /// Same value as [`BatchSampler::num_batches`], offered under the
    /// conventional `len` name.
    fn len(&self) -> usize {
        let batches = self.num_batches();
        batches
    }

    /// `true` when a pass would yield no batches.
    fn is_empty(&self) -> bool {
        let batches = self.num_batches();
        batches == 0
    }
}
109
/// Owning iterator over a precomputed list of indices.
///
/// The indices are stored up front and handed out one at a time; a cursor
/// records how many have been yielded so the number still remaining is always
/// known exactly.
///
/// `Debug` and `Clone` are derived so the iterator can be logged and so a
/// pass can be duplicated mid-way (the clone keeps the same cursor).
#[derive(Debug, Clone)]
pub struct SamplerIterator {
    /// Indices to yield, in order.
    indices: Vec<usize>,
    /// Offset into `indices` of the next element to yield.
    position: usize,
}
115
116impl SamplerIterator {
117    /// Create a new sampler iterator
118    pub fn new(indices: Vec<usize>) -> Self {
119        Self {
120            indices,
121            position: 0,
122        }
123    }
124
125    /// Create from a range
126    pub fn from_range(start: usize, end: usize) -> Self {
127        Self::new((start..end).collect())
128    }
129
130    /// Create shuffled indices
131    pub fn shuffled(mut indices: Vec<usize>, seed: Option<u64>) -> Self {
132        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
133        let mut rng = match seed {
134            Some(s) => Random::seed(s),
135            None => Random::seed(42), // Use fixed seed instead of Random::new()
136        };
137
138        // Fisher-Yates shuffle
139        for i in (1..indices.len()).rev() {
140            let j = rng.gen_range(0..=i);
141            indices.swap(i, j);
142        }
143
144        Self::new(indices)
145    }
146
147    /// Get remaining items count
148    pub fn remaining(&self) -> usize {
149        self.indices.len() - self.position
150    }
151}
152
153impl Iterator for SamplerIterator {
154    type Item = usize;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.position < self.indices.len() {
158            let item = self.indices[self.position];
159            self.position += 1;
160            Some(item)
161        } else {
162            None
163        }
164    }
165
166    fn size_hint(&self) -> (usize, Option<usize>) {
167        let remaining = self.remaining();
168        (remaining, Some(remaining))
169    }
170}
171
172impl ExactSizeIterator for SamplerIterator {
173    fn len(&self) -> usize {
174        self.remaining()
175    }
176}
177
178/// Utility functions for sampling operations
179pub mod utils {
180    use super::*;
181
182    /// Generate random indices without replacement
183    pub fn random_indices(n: usize, k: usize, seed: Option<u64>) -> Vec<usize> {
184        assert!(k <= n, "Cannot sample more items than available");
185
186        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
187        let mut rng = match seed {
188            Some(s) => Random::seed(s),
189            None => Random::seed(42),
190        };
191
192        if k == n {
193            // Return all indices shuffled
194            let mut indices: Vec<usize> = (0..n).collect();
195            for i in (1..indices.len()).rev() {
196                let j = rng.gen_range(0..=i);
197                indices.swap(i, j);
198            }
199            indices
200        } else if k <= n / 2 {
201            // Use rejection sampling for small k
202            let mut selected = std::collections::HashSet::new();
203            while selected.len() < k {
204                let idx = rng.gen_range(0..n);
205                selected.insert(idx);
206            }
207            let mut result: Vec<usize> = selected.into_iter().collect();
208            result.sort_unstable(); // Ensure deterministic ordering
209            result
210        } else {
211            // Use exclusion method for large k
212            let mut excluded = std::collections::HashSet::new();
213            while excluded.len() < n - k {
214                let idx = rng.gen_range(0..n);
215                excluded.insert(idx);
216            }
217            let mut result: Vec<usize> = (0..n).filter(|&i| !excluded.contains(&i)).collect();
218            result.sort_unstable(); // Ensure deterministic ordering
219            result
220        }
221    }
222
223    /// Stratified split of indices
224    pub fn stratified_split(
225        indices: &[usize],
226        labels: &[usize],
227        test_ratio: f32,
228        seed: Option<u64>,
229    ) -> (Vec<usize>, Vec<usize>) {
230        use std::collections::HashMap;
231
232        // Group indices by label
233        let mut label_groups: HashMap<usize, Vec<usize>> = HashMap::new();
234        for &idx in indices {
235            if idx < labels.len() {
236                label_groups
237                    .entry(labels[idx])
238                    .or_insert_with(Vec::new)
239                    .push(idx);
240            }
241        }
242
243        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
244        let mut rng = match seed {
245            Some(s) => Random::seed(s),
246            None => Random::seed(42),
247        };
248
249        let mut train_indices = Vec::new();
250        let mut test_indices = Vec::new();
251
252        // Split each label group
253        for (_, mut group_indices) in label_groups {
254            // Shuffle group indices
255            for i in (1..group_indices.len()).rev() {
256                let j = rng.gen_range(0..=i);
257                group_indices.swap(i, j);
258            }
259
260            let test_size = ((group_indices.len() as f32) * test_ratio).round() as usize;
261            let test_size = test_size.min(group_indices.len());
262
263            test_indices.extend(group_indices.iter().take(test_size));
264            train_indices.extend(group_indices.iter().skip(test_size));
265        }
266
267        (train_indices, test_indices)
268    }
269
270    /// Calculate class weights for balanced sampling
271    pub fn calculate_class_weights(labels: &[usize], num_classes: usize) -> Vec<f32> {
272        let mut class_counts = vec![0usize; num_classes];
273
274        // Count occurrences
275        for &label in labels {
276            if label < num_classes {
277                class_counts[label] += 1;
278            }
279        }
280
281        // Calculate weights (inverse frequency)
282        let total_samples = labels.len() as f32;
283        class_counts
284            .iter()
285            .map(|&count| {
286                if count > 0 {
287                    total_samples / (num_classes as f32 * count as f32)
288                } else {
289                    0.0
290                }
291            })
292            .collect()
293    }
294
295    /// Validate sampling configuration
296    pub fn validate_sampling_params(
297        dataset_size: usize,
298        num_samples: Option<usize>,
299        replacement: bool,
300    ) -> Result<usize, String> {
301        let actual_num_samples = num_samples.unwrap_or(dataset_size);
302
303        // Allow empty datasets (dataset_size == 0) with 0 samples
304        if dataset_size == 0 {
305            if actual_num_samples == 0 {
306                return Ok(0);
307            } else {
308                return Err("Cannot sample from empty dataset".to_string());
309            }
310        }
311
312        if !replacement && actual_num_samples > dataset_size {
313            return Err(format!(
314                "Cannot sample {} items without replacement from dataset of size {}",
315                actual_num_samples, dataset_size
316            ));
317        }
318
319        if actual_num_samples == 0 && !replacement {
320            return Err(
321                "Number of samples cannot be zero for non-empty dataset without replacement"
322                    .to_string(),
323            );
324        }
325
326        Ok(actual_num_samples)
327    }
328
329    /// Simple train-validation split
330    pub fn train_val_split(
331        dataset_size: usize,
332        val_ratio: f32,
333        seed: Option<u64>,
334    ) -> (Vec<usize>, Vec<usize>) {
335        let val_size = (dataset_size as f32 * val_ratio).round() as usize;
336        let indices = random_indices(dataset_size, dataset_size, seed);
337
338        let (val_indices, train_indices) = indices.split_at(val_size);
339        (train_indices.to_vec(), val_indices.to_vec())
340    }
341
342    /// Generate k-fold cross-validation splits
343    pub fn kfold_splits(
344        dataset_size: usize,
345        k: usize,
346        seed: Option<u64>,
347    ) -> Vec<(Vec<usize>, Vec<usize>)> {
348        assert!(k > 1, "K must be greater than 1");
349        assert!(k <= dataset_size, "K cannot be larger than dataset size");
350
351        let indices = random_indices(dataset_size, dataset_size, seed);
352        let fold_size = dataset_size / k;
353        let mut splits = Vec::new();
354
355        for i in 0..k {
356            let start = i * fold_size;
357            let end = if i == k - 1 {
358                dataset_size // Last fold gets remaining samples
359            } else {
360                (i + 1) * fold_size
361            };
362
363            let val_indices = indices[start..end].to_vec();
364            let train_indices = [&indices[..start], &indices[end..]].concat();
365            splits.push((train_indices, val_indices));
366        }
367
368        splits
369    }
370
371    /// Three-way split: train, validation, test
372    pub fn train_val_test_split(
373        dataset_size: usize,
374        train_ratio: f32,
375        val_ratio: f32,
376        seed: Option<u64>,
377    ) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
378        assert!(
379            train_ratio + val_ratio < 1.0,
380            "Train and val ratios must sum to less than 1.0"
381        );
382        assert!(
383            train_ratio > 0.0 && val_ratio > 0.0,
384            "Ratios must be positive"
385        );
386
387        let train_size = (dataset_size as f32 * train_ratio).round() as usize;
388        let val_size = (dataset_size as f32 * val_ratio).round() as usize;
389        let _test_size = dataset_size - train_size - val_size;
390
391        let indices = random_indices(dataset_size, dataset_size, seed);
392
393        let train_indices = indices[..train_size].to_vec();
394        let val_indices = indices[train_size..train_size + val_size].to_vec();
395        let test_indices = indices[train_size + val_size..].to_vec();
396
397        (train_indices, val_indices, test_indices)
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh iterator reports its full length and yields the indices in the
    /// order they were supplied.
    #[test]
    fn test_sampler_iterator_basic() {
        let expected = vec![0, 1, 2, 3, 4];
        let iter = SamplerIterator::new(expected.clone());

        assert_eq!(iter.len(), 5);
        assert_eq!(iter.remaining(), 5);
        assert_eq!(iter.collect::<Vec<usize>>(), expected);
    }

    /// `from_range` covers the half-open range in ascending order.
    #[test]
    fn test_sampler_iterator_from_range() {
        let yielded: Vec<usize> = SamplerIterator::from_range(0, 5).collect();
        assert_eq!(yielded, vec![0, 1, 2, 3, 4]);
    }

    /// Shuffling is a permutation: the length is unchanged and every original
    /// index is still present.
    #[test]
    fn test_sampler_iterator_shuffled() {
        let original = vec![0, 1, 2, 3, 4];
        let yielded: Vec<usize> = SamplerIterator::shuffled(original.clone(), Some(42)).collect();

        assert_eq!(yielded.len(), original.len());
        assert!(original.iter().all(|idx| yielded.contains(idx)));
    }

    /// A partial sample has the requested size, no repeats, and stays in
    /// range.
    #[test]
    fn test_utils_random_indices() {
        let picked = utils::random_indices(10, 5, Some(42));
        assert_eq!(picked.len(), 5);

        let distinct: std::collections::HashSet<usize> = picked.iter().copied().collect();
        assert_eq!(distinct.len(), 5);

        assert!(picked.iter().all(|&idx| idx < 10));
    }

    /// Sampling everything returns each index exactly once.
    #[test]
    fn test_utils_random_indices_all() {
        let mut picked = utils::random_indices(5, 5, Some(42));
        assert_eq!(picked.len(), 5);

        picked.sort();
        assert_eq!(picked, vec![0, 1, 2, 3, 4]);
    }

    /// The stratified split partitions the input: nothing lost, nothing
    /// duplicated, and the test side is not empty.
    #[test]
    fn test_utils_stratified_split() {
        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];

        let (train, test) = utils::stratified_split(&indices, &labels, 0.3, Some(42));

        assert_eq!(train.len() + test.len(), indices.len());
        assert!(test.len() >= 2);

        let mut combined: Vec<usize> = train.iter().chain(test.iter()).copied().collect();
        combined.sort();
        assert_eq!(combined, indices);
    }

    /// Rarer classes receive larger weights.
    #[test]
    fn test_utils_calculate_class_weights() {
        // Class counts: 0 -> 2, 1 -> 3, 2 -> 1.
        let labels = vec![0, 0, 1, 1, 1, 2];
        let weights = utils::calculate_class_weights(&labels, 3);

        assert_eq!(weights.len(), 3);

        // The most frequent class (1) must weigh less than both others.
        assert!(weights[2] > weights[1]);
        assert!(weights[0] > weights[1]);
    }

    /// Accepted and rejected combinations of dataset size, sample count and
    /// replacement flag.
    #[test]
    fn test_utils_validate_sampling_params() {
        let accepted = [
            (10, Some(5), false),
            (10, Some(15), true),
            (10, None, false),
            // Empty dataset with nothing requested.
            (0, Some(0), false),
            (0, None, false),
            // Zero samples is fine when sampling with replacement.
            (10, Some(0), true),
        ];
        for &(size, count, replacement) in accepted.iter() {
            assert!(utils::validate_sampling_params(size, count, replacement).is_ok());
        }

        let rejected = [
            // Cannot draw from an empty dataset.
            (0, Some(5), false),
            // Zero samples without replacement from a non-empty dataset.
            (10, Some(0), false),
            // More samples than items, without replacement.
            (10, Some(15), false),
        ];
        for &(size, count, replacement) in rejected.iter() {
            assert!(utils::validate_sampling_params(size, count, replacement).is_err());
        }
    }

    /// `size_hint` is exact and shrinks as items are consumed.
    #[test]
    fn test_size_hints() {
        let untouched = SamplerIterator::new(vec![0, 1, 2]);
        assert_eq!(untouched.size_hint(), (3, Some(3)));

        let mut advanced = SamplerIterator::new(vec![0, 1, 2]);
        advanced.next();
        assert_eq!(advanced.size_hint(), (2, Some(2)));
    }
}