scirs2_datasets/utils/sampling.rs

1//! Data sampling utilities for statistical analysis and machine learning
2//!
3//! This module provides various sampling strategies including random sampling,
4//! stratified sampling, and importance-weighted sampling. These functions are
5//! useful for creating representative subsets of datasets, bootstrap sampling,
6//! and handling imbalanced data distributions.
7
8use crate::error::{DatasetsError, Result};
9use scirs2_core::ndarray::Array1;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::rngs::StdRng;
13use scirs2_core::random::seq::SliceRandom;
14use scirs2_core::random::Uniform;
15use std::collections::HashMap;
16
17/// Performs random sampling with or without replacement
18///
19/// This function creates random samples from a dataset using either bootstrap
20/// sampling (with replacement) or standard random sampling (without replacement).
21///
22/// # Arguments
23///
24/// * `n_samples` - Total number of samples in the dataset
25/// * `sample_size` - Number of samples to draw
26/// * `replace` - Whether to sample with replacement (bootstrap)
27/// * `random_seed` - Optional random seed for reproducible sampling
28///
29/// # Returns
30///
31/// A vector of indices representing the sampled data points
32///
33/// # Examples
34///
35/// ```rust
36/// use scirs2_datasets::utils::random_sample;
37///
38/// // Sample 5 indices from 10 total samples without replacement
39/// let indices = random_sample(10, 5, false, Some(42)).unwrap();
40/// assert_eq!(indices.len(), 5);
41/// assert!(indices.iter().all(|&i| i < 10));
42///
43/// // Bootstrap sampling (with replacement)
44/// let bootstrap_indices = random_sample(10, 15, true, Some(42)).unwrap();
45/// assert_eq!(bootstrap_indices.len(), 15);
46/// ```
47#[allow(dead_code)]
48pub fn random_sample(
49    n_samples: usize,
50    sample_size: usize,
51    replace: bool,
52    random_seed: Option<u64>,
53) -> Result<Vec<usize>> {
54    if n_samples == 0 {
55        return Err(DatasetsError::InvalidFormat(
56            "Number of _samples must be > 0".to_string(),
57        ));
58    }
59
60    if sample_size == 0 {
61        return Err(DatasetsError::InvalidFormat(
62            "Sample _size must be > 0".to_string(),
63        ));
64    }
65
66    if !replace && sample_size > n_samples {
67        return Err(DatasetsError::InvalidFormat(format!(
68            "Cannot sample {sample_size} items from {n_samples} without replacement"
69        )));
70    }
71
72    let mut rng = match random_seed {
73        Some(_seed) => StdRng::seed_from_u64(_seed),
74        None => {
75            let mut r = thread_rng();
76            StdRng::seed_from_u64(r.next_u64())
77        }
78    };
79
80    let mut indices = Vec::with_capacity(sample_size);
81
82    if replace {
83        // Bootstrap sampling (with replacement)
84        for _ in 0..sample_size {
85            indices.push(rng.sample(Uniform::new(0, n_samples).unwrap()));
86        }
87    } else {
88        // Sampling without replacement
89        let mut available: Vec<usize> = (0..n_samples).collect();
90        available.shuffle(&mut rng);
91        indices.extend_from_slice(&available[0..sample_size]);
92    }
93
94    Ok(indices)
95}
96
97/// Performs stratified random sampling
98///
99/// Maintains the same class distribution in the sample as in the original dataset.
100/// This is particularly useful for classification tasks where you want to ensure
101/// that all classes are represented proportionally in your sample.
102///
103/// # Arguments
104///
105/// * `targets` - Target values for stratification
106/// * `sample_size` - Number of samples to draw
107/// * `random_seed` - Optional random seed for reproducible sampling
108///
109/// # Returns
110///
111/// A vector of indices representing the stratified sample
112///
113/// # Examples
114///
115/// ```rust
116/// use scirs2_core::ndarray::Array1;
117/// use scirs2_datasets::utils::stratified_sample;
118///
119/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
120/// let indices = stratified_sample(&targets, 6, Some(42)).unwrap();
121/// assert_eq!(indices.len(), 6);
122///
123/// // Check that the sample maintains class proportions
124/// let mut class_counts = std::collections::HashMap::new();
125/// for &idx in &indices {
126///     let class = targets[idx] as i32;
127///     *class_counts.entry(class).or_insert(0) += 1;
128/// }
129/// ```
130#[allow(dead_code)]
131pub fn stratified_sample(
132    targets: &Array1<f64>,
133    sample_size: usize,
134    random_seed: Option<u64>,
135) -> Result<Vec<usize>> {
136    if targets.is_empty() {
137        return Err(DatasetsError::InvalidFormat(
138            "Targets array cannot be empty".to_string(),
139        ));
140    }
141
142    if sample_size == 0 {
143        return Err(DatasetsError::InvalidFormat(
144            "Sample _size must be > 0".to_string(),
145        ));
146    }
147
148    if sample_size > targets.len() {
149        return Err(DatasetsError::InvalidFormat(format!(
150            "Cannot sample {} items from {} total samples",
151            sample_size,
152            targets.len()
153        )));
154    }
155
156    // Group indices by target class
157    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
158    for (i, &target) in targets.iter().enumerate() {
159        let class = target.round() as i64;
160        class_indices.entry(class).or_default().push(i);
161    }
162
163    let mut rng = match random_seed {
164        Some(_seed) => StdRng::seed_from_u64(_seed),
165        None => {
166            let mut r = thread_rng();
167            StdRng::seed_from_u64(r.next_u64())
168        }
169    };
170
171    let mut stratified_indices = Vec::new();
172    let n_classes = class_indices.len();
173    let base_samples_per_class = sample_size / n_classes;
174    let remainder = sample_size % n_classes;
175
176    let mut class_list: Vec<_> = class_indices.keys().cloned().collect();
177    class_list.sort();
178
179    for (i, &class) in class_list.iter().enumerate() {
180        let class_samples = class_indices.get(&class).unwrap();
181        let samples_for_this_class = if i < remainder {
182            base_samples_per_class + 1
183        } else {
184            base_samples_per_class
185        };
186
187        if samples_for_this_class > class_samples.len() {
188            return Err(DatasetsError::InvalidFormat(format!(
189                "Class {} has only {} samples but needs {} for stratified sampling",
190                class,
191                class_samples.len(),
192                samples_for_this_class
193            )));
194        }
195
196        // Sample from this class
197        let sampled_indices = random_sample(
198            class_samples.len(),
199            samples_for_this_class,
200            false,
201            Some(rng.next_u64()),
202        )?;
203
204        for &idx in &sampled_indices {
205            stratified_indices.push(class_samples[idx]);
206        }
207    }
208
209    stratified_indices.shuffle(&mut rng);
210    Ok(stratified_indices)
211}
212
213/// Performs importance sampling based on provided weights
214///
215/// Samples indices according to the provided probability weights. Higher weights
216/// increase the probability of selection. This is useful for adaptive sampling
217/// where some samples are more important than others for training.
218///
219/// # Arguments
220///
221/// * `weights` - Probability weights for each sample (must be non-negative)
222/// * `sample_size` - Number of samples to draw
223/// * `replace` - Whether to sample with replacement
224/// * `random_seed` - Optional random seed for reproducible sampling
225///
226/// # Returns
227///
228/// A vector of indices representing the importance-weighted sample
229///
230/// # Examples
231///
232/// ```rust
233/// use scirs2_core::ndarray::Array1;
234/// use scirs2_datasets::utils::importance_sample;
235///
236/// // Give higher weights to the last few samples
237/// let weights = Array1::from(vec![0.1, 0.1, 0.1, 0.8, 0.9, 1.0]);
238/// let indices = importance_sample(&weights, 3, false, Some(42)).unwrap();
239/// assert_eq!(indices.len(), 3);
240///
241/// // Higher weighted samples should be more likely to be selected
242/// let mut high_weight_count = 0;
243/// for &idx in &indices {
244///     if idx >= 3 { // Last three samples have higher weights
245///         high_weight_count += 1;
246///     }
247/// }
248/// // This should be true with high probability
249/// assert!(high_weight_count >= 1);
250/// ```
251#[allow(dead_code)]
252pub fn importance_sample(
253    weights: &Array1<f64>,
254    sample_size: usize,
255    replace: bool,
256    random_seed: Option<u64>,
257) -> Result<Vec<usize>> {
258    if weights.is_empty() {
259        return Err(DatasetsError::InvalidFormat(
260            "Weights array cannot be empty".to_string(),
261        ));
262    }
263
264    if sample_size == 0 {
265        return Err(DatasetsError::InvalidFormat(
266            "Sample _size must be > 0".to_string(),
267        ));
268    }
269
270    if !replace && sample_size > weights.len() {
271        return Err(DatasetsError::InvalidFormat(format!(
272            "Cannot sample {} items from {} without replacement",
273            sample_size,
274            weights.len()
275        )));
276    }
277
278    // Check for negative weights
279    for &weight in weights.iter() {
280        if weight < 0.0 {
281            return Err(DatasetsError::InvalidFormat(
282                "All weights must be non-negative".to_string(),
283            ));
284        }
285    }
286
287    let weight_sum: f64 = weights.sum();
288    if weight_sum <= 0.0 {
289        return Err(DatasetsError::InvalidFormat(
290            "Sum of weights must be positive".to_string(),
291        ));
292    }
293
294    let mut rng = match random_seed {
295        Some(_seed) => StdRng::seed_from_u64(_seed),
296        None => {
297            let mut r = thread_rng();
298            StdRng::seed_from_u64(r.next_u64())
299        }
300    };
301
302    let mut indices = Vec::with_capacity(sample_size);
303    let mut available_weights = weights.clone();
304    let mut available_indices: Vec<usize> = (0..weights.len()).collect();
305
306    for _ in 0..sample_size {
307        let current_sum = available_weights.sum();
308        if current_sum <= 0.0 {
309            break;
310        }
311
312        // Generate random number between 0 and current_sum
313        let random_value = rng.gen_range(0.0..current_sum);
314
315        // Find the index corresponding to this random value
316        let mut cumulative_weight = 0.0;
317        let mut selected_idx = 0;
318
319        for (i, &weight) in available_weights.iter().enumerate() {
320            cumulative_weight += weight;
321            if random_value <= cumulative_weight {
322                selected_idx = i;
323                break;
324            }
325        }
326
327        let original_idx = available_indices[selected_idx];
328        indices.push(original_idx);
329
330        if !replace {
331            // Remove the selected item for sampling without replacement
332            available_weights = Array1::from_iter(
333                available_weights
334                    .iter()
335                    .enumerate()
336                    .filter(|(i_, _)| *i_ != selected_idx)
337                    .map(|(_, &w)| w),
338            );
339            available_indices.remove(selected_idx);
340        }
341    }
342
343    Ok(indices)
344}
345
346/// Generate bootstrap samples from indices
347///
348/// This is a convenience function that generates bootstrap samples (sampling with
349/// replacement) which is commonly used for bootstrap confidence intervals and
350/// ensemble methods.
351///
352/// # Arguments
353///
354/// * `n_samples` - Total number of samples in the dataset
355/// * `n_bootstrap_samples` - Number of bootstrap samples to generate
356/// * `random_seed` - Optional random seed for reproducible sampling
357///
358/// # Returns
359///
360/// A vector of bootstrap sample indices
361///
362/// # Examples
363///
364/// ```rust
365/// use scirs2_datasets::utils::bootstrap_sample;
366///
367/// let bootstrap_indices = bootstrap_sample(100, 100, Some(42)).unwrap();
368/// assert_eq!(bootstrap_indices.len(), 100);
369///
370/// // Some indices should appear multiple times (with high probability)
371/// let mut unique_indices = bootstrap_indices.clone();
372/// unique_indices.sort();
373/// unique_indices.dedup();
374/// assert!(unique_indices.len() < bootstrap_indices.len());
375/// ```
376#[allow(dead_code)]
377pub fn bootstrap_sample(
378    n_samples: usize,
379    n_bootstrap_samples: usize,
380    random_seed: Option<u64>,
381) -> Result<Vec<usize>> {
382    random_sample(n_samples, n_bootstrap_samples, true, random_seed)
383}
384
385/// Generate multiple bootstrap samples
386///
387/// Creates multiple independent bootstrap samples, useful for ensemble methods
388/// like bagging or for computing bootstrap confidence intervals.
389///
390/// # Arguments
391///
392/// * `n_samples` - Total number of samples in the dataset
393/// * `sample_size` - Size of each bootstrap sample
394/// * `n_bootstrap_rounds` - Number of bootstrap samples to generate
395/// * `random_seed` - Optional random seed for reproducible sampling
396///
397/// # Returns
398///
399/// A vector of bootstrap sample vectors
400///
401/// # Examples
402///
403/// ```rust
404/// use scirs2_datasets::utils::multiple_bootstrap_samples;
405///
406/// let bootstrap_samples = multiple_bootstrap_samples(50, 50, 10, Some(42)).unwrap();
407/// assert_eq!(bootstrap_samples.len(), 10);
408/// assert!(bootstrap_samples.iter().all(|sample| sample.len() == 50));
409/// ```
410#[allow(dead_code)]
411pub fn multiple_bootstrap_samples(
412    n_samples: usize,
413    sample_size: usize,
414    n_bootstrap_rounds: usize,
415    random_seed: Option<u64>,
416) -> Result<Vec<Vec<usize>>> {
417    if n_bootstrap_rounds == 0 {
418        return Err(DatasetsError::InvalidFormat(
419            "Number of bootstrap _rounds must be > 0".to_string(),
420        ));
421    }
422
423    let mut rng = match random_seed {
424        Some(_seed) => StdRng::seed_from_u64(_seed),
425        None => {
426            let mut r = thread_rng();
427            StdRng::seed_from_u64(r.next_u64())
428        }
429    };
430
431    let mut bootstrap_samples = Vec::with_capacity(n_bootstrap_rounds);
432
433    for _ in 0..n_bootstrap_rounds {
434        let sample = random_sample(n_samples, sample_size, true, Some(rng.next_u64()))?;
435        bootstrap_samples.push(sample);
436    }
437
438    Ok(bootstrap_samples)
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashSet;

    #[test]
    fn test_random_sample_without_replacement() {
        let indices = random_sample(10, 5, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 10));

        // All indices should be unique (no replacement)
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 5);
    }

    #[test]
    fn test_random_sample_with_replacement() {
        let indices = random_sample(5, 10, true, Some(42)).unwrap();

        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i < 5));

        // Some indices might be repeated (with replacement)
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() <= 10);
    }

    #[test]
    fn test_random_sample_invalid_params() {
        // Zero samples
        assert!(random_sample(0, 5, false, None).is_err());

        // Zero sample size
        assert!(random_sample(10, 0, false, None).is_err());

        // Too many samples without replacement
        assert!(random_sample(5, 10, false, None).is_err());
    }

    #[test]
    fn test_sampling_is_reproducible_with_seed() {
        assert_eq!(
            random_sample(50, 10, false, Some(7)).unwrap(),
            random_sample(50, 10, false, Some(7)).unwrap()
        );
        assert_eq!(
            bootstrap_sample(30, 30, Some(7)).unwrap(),
            bootstrap_sample(30, 30, Some(7)).unwrap()
        );
    }

    #[test]
    fn test_stratified_sample() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]; // 2, 3, 3 samples per class
        let indices = stratified_sample(&targets, 6, Some(42)).unwrap();

        assert_eq!(indices.len(), 6);

        // Sampling within each class is without replacement.
        let unique_indices: HashSet<_> = indices.iter().copied().collect();
        assert_eq!(unique_indices.len(), 6);

        // Count samples per class in the result
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &idx in &indices {
            let class = targets[idx] as i32;
            *class_counts.entry(class).or_insert(0) += 1;
        }

        // 6 of 8 samples across classes of size 2/3/3 yields 2/2/2.
        assert_eq!(class_counts.len(), 3);
        assert_eq!(class_counts.get(&0), Some(&2));
        assert_eq!(class_counts.get(&1), Some(&2));
        assert_eq!(class_counts.get(&2), Some(&2));
    }

    #[test]
    fn test_stratified_sample_insufficient_samples() {
        let targets = array![0.0, 1.0]; // Only 1 sample per class
                                        // Requesting 4 samples but only 2 total
        assert!(stratified_sample(&targets, 4, Some(42)).is_err());
    }

    #[test]
    fn test_importance_sample() {
        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0]; // Higher weights at the end
        let indices = importance_sample(&weights, 3, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&i| i < 6));

        // All indices should be unique (no replacement)
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 3);
    }

    #[test]
    fn test_importance_sample_with_replacement_skips_zero_weights() {
        let weights = array![0.0, 1.0, 0.0];
        let indices = importance_sample(&weights, 20, true, Some(42)).unwrap();

        assert_eq!(indices.len(), 20);
        assert!(indices.iter().all(|&i| i == 1));
    }

    #[test]
    fn test_importance_sample_too_many_without_replacement() {
        let weights = array![0.2, 0.3];
        assert!(importance_sample(&weights, 3, false, None).is_err());
    }

    #[test]
    fn test_importance_sample_negative_weights() {
        let weights = array![0.5, -0.1, 0.3]; // Contains negative weight
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_importance_sample_zero_weights() {
        let weights = array![0.0, 0.0, 0.0]; // All zero weights
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_bootstrap_sample() {
        let indices = bootstrap_sample(20, 20, Some(42)).unwrap();

        assert_eq!(indices.len(), 20);
        assert!(indices.iter().all(|&i| i < 20));

        // Should likely have some repeated indices
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() < 20); // Very likely with size 20
    }

    #[test]
    fn test_multiple_bootstrap_samples() {
        let samples = multiple_bootstrap_samples(10, 8, 5, Some(42)).unwrap();

        assert_eq!(samples.len(), 5);
        assert!(samples.iter().all(|sample| sample.len() == 8));
        assert!(samples.iter().all(|sample| sample.iter().all(|&i| i < 10)));

        // Different bootstrap samples should be different
        assert_ne!(samples[0], samples[1]); // Very likely to be different
    }

    #[test]
    fn test_multiple_bootstrap_samples_invalid_params() {
        assert!(multiple_bootstrap_samples(10, 10, 0, None).is_err());
    }
}