scirs2_datasets/utils/
sampling.rs

//! Data sampling utilities for statistical analysis and machine learning
//!
//! This module provides various sampling strategies including random sampling,
//! stratified sampling, and importance-weighted sampling. These functions are
//! useful for creating representative subsets of datasets, bootstrap sampling,
//! and handling imbalanced data distributions.
7
8use crate::error::{DatasetsError, Result};
9use ndarray::Array1;
10use rand::prelude::*;
11use rand::rng;
12use rand::rngs::StdRng;
13use std::collections::HashMap;
14
15/// Performs random sampling with or without replacement
16///
17/// This function creates random samples from a dataset using either bootstrap
18/// sampling (with replacement) or standard random sampling (without replacement).
19///
20/// # Arguments
21///
22/// * `n_samples` - Total number of samples in the dataset
23/// * `sample_size` - Number of samples to draw
24/// * `replace` - Whether to sample with replacement (bootstrap)
25/// * `random_seed` - Optional random seed for reproducible sampling
26///
27/// # Returns
28///
29/// A vector of indices representing the sampled data points
30///
31/// # Examples
32///
33/// ```rust
34/// use scirs2_datasets::utils::random_sample;
35///
36/// // Sample 5 indices from 10 total samples without replacement
37/// let indices = random_sample(10, 5, false, Some(42)).unwrap();
38/// assert_eq!(indices.len(), 5);
39/// assert!(indices.iter().all(|&i| i < 10));
40///
41/// // Bootstrap sampling (with replacement)
42/// let bootstrap_indices = random_sample(10, 15, true, Some(42)).unwrap();
43/// assert_eq!(bootstrap_indices.len(), 15);
44/// ```
45pub fn random_sample(
46    n_samples: usize,
47    sample_size: usize,
48    replace: bool,
49    random_seed: Option<u64>,
50) -> Result<Vec<usize>> {
51    if n_samples == 0 {
52        return Err(DatasetsError::InvalidFormat(
53            "Number of samples must be > 0".to_string(),
54        ));
55    }
56
57    if sample_size == 0 {
58        return Err(DatasetsError::InvalidFormat(
59            "Sample size must be > 0".to_string(),
60        ));
61    }
62
63    if !replace && sample_size > n_samples {
64        return Err(DatasetsError::InvalidFormat(format!(
65            "Cannot sample {} items from {} without replacement",
66            sample_size, n_samples
67        )));
68    }
69
70    let mut rng = match random_seed {
71        Some(seed) => StdRng::seed_from_u64(seed),
72        None => {
73            let mut r = rng();
74            StdRng::seed_from_u64(r.next_u64())
75        }
76    };
77
78    let mut indices = Vec::with_capacity(sample_size);
79
80    if replace {
81        // Bootstrap sampling (with replacement)
82        for _ in 0..sample_size {
83            indices.push(rng.random_range(0..n_samples));
84        }
85    } else {
86        // Sampling without replacement
87        let mut available: Vec<usize> = (0..n_samples).collect();
88        available.shuffle(&mut rng);
89        indices.extend_from_slice(&available[0..sample_size]);
90    }
91
92    Ok(indices)
93}
94
95/// Performs stratified random sampling
96///
97/// Maintains the same class distribution in the sample as in the original dataset.
98/// This is particularly useful for classification tasks where you want to ensure
99/// that all classes are represented proportionally in your sample.
100///
101/// # Arguments
102///
103/// * `targets` - Target values for stratification
104/// * `sample_size` - Number of samples to draw
105/// * `random_seed` - Optional random seed for reproducible sampling
106///
107/// # Returns
108///
109/// A vector of indices representing the stratified sample
110///
111/// # Examples
112///
113/// ```rust
114/// use ndarray::Array1;
115/// use scirs2_datasets::utils::stratified_sample;
116///
117/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
118/// let indices = stratified_sample(&targets, 6, Some(42)).unwrap();
119/// assert_eq!(indices.len(), 6);
120///
121/// // Check that the sample maintains class proportions
122/// let mut class_counts = std::collections::HashMap::new();
123/// for &idx in &indices {
124///     let class = targets[idx] as i32;
125///     *class_counts.entry(class).or_insert(0) += 1;
126/// }
127/// ```
128pub fn stratified_sample(
129    targets: &Array1<f64>,
130    sample_size: usize,
131    random_seed: Option<u64>,
132) -> Result<Vec<usize>> {
133    if targets.is_empty() {
134        return Err(DatasetsError::InvalidFormat(
135            "Targets array cannot be empty".to_string(),
136        ));
137    }
138
139    if sample_size == 0 {
140        return Err(DatasetsError::InvalidFormat(
141            "Sample size must be > 0".to_string(),
142        ));
143    }
144
145    if sample_size > targets.len() {
146        return Err(DatasetsError::InvalidFormat(format!(
147            "Cannot sample {} items from {} total samples",
148            sample_size,
149            targets.len()
150        )));
151    }
152
153    // Group indices by target class
154    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
155    for (i, &target) in targets.iter().enumerate() {
156        let class = target.round() as i64;
157        class_indices.entry(class).or_default().push(i);
158    }
159
160    let mut rng = match random_seed {
161        Some(seed) => StdRng::seed_from_u64(seed),
162        None => {
163            let mut r = rng();
164            StdRng::seed_from_u64(r.next_u64())
165        }
166    };
167
168    let mut stratified_indices = Vec::new();
169    let n_classes = class_indices.len();
170    let base_samples_per_class = sample_size / n_classes;
171    let remainder = sample_size % n_classes;
172
173    let mut class_list: Vec<_> = class_indices.keys().cloned().collect();
174    class_list.sort();
175
176    for (i, &class) in class_list.iter().enumerate() {
177        let class_samples = class_indices.get(&class).unwrap();
178        let samples_for_this_class = if i < remainder {
179            base_samples_per_class + 1
180        } else {
181            base_samples_per_class
182        };
183
184        if samples_for_this_class > class_samples.len() {
185            return Err(DatasetsError::InvalidFormat(format!(
186                "Class {} has only {} samples but needs {} for stratified sampling",
187                class,
188                class_samples.len(),
189                samples_for_this_class
190            )));
191        }
192
193        // Sample from this class
194        let sampled_indices = random_sample(
195            class_samples.len(),
196            samples_for_this_class,
197            false,
198            Some(rng.next_u64()),
199        )?;
200
201        for &idx in &sampled_indices {
202            stratified_indices.push(class_samples[idx]);
203        }
204    }
205
206    stratified_indices.shuffle(&mut rng);
207    Ok(stratified_indices)
208}
209
210/// Performs importance sampling based on provided weights
211///
212/// Samples indices according to the provided probability weights. Higher weights
213/// increase the probability of selection. This is useful for adaptive sampling
214/// where some samples are more important than others for training.
215///
216/// # Arguments
217///
218/// * `weights` - Probability weights for each sample (must be non-negative)
219/// * `sample_size` - Number of samples to draw
220/// * `replace` - Whether to sample with replacement
221/// * `random_seed` - Optional random seed for reproducible sampling
222///
223/// # Returns
224///
225/// A vector of indices representing the importance-weighted sample
226///
227/// # Examples
228///
229/// ```rust
230/// use ndarray::Array1;
231/// use scirs2_datasets::utils::importance_sample;
232///
233/// // Give higher weights to the last few samples
234/// let weights = Array1::from(vec![0.1, 0.1, 0.1, 0.8, 0.9, 1.0]);
235/// let indices = importance_sample(&weights, 3, false, Some(42)).unwrap();
236/// assert_eq!(indices.len(), 3);
237///
238/// // Higher weighted samples should be more likely to be selected
239/// let mut high_weight_count = 0;
240/// for &idx in &indices {
241///     if idx >= 3 { // Last three samples have higher weights
242///         high_weight_count += 1;
243///     }
244/// }
245/// // This should be true with high probability
246/// assert!(high_weight_count >= 1);
247/// ```
248pub fn importance_sample(
249    weights: &Array1<f64>,
250    sample_size: usize,
251    replace: bool,
252    random_seed: Option<u64>,
253) -> Result<Vec<usize>> {
254    if weights.is_empty() {
255        return Err(DatasetsError::InvalidFormat(
256            "Weights array cannot be empty".to_string(),
257        ));
258    }
259
260    if sample_size == 0 {
261        return Err(DatasetsError::InvalidFormat(
262            "Sample size must be > 0".to_string(),
263        ));
264    }
265
266    if !replace && sample_size > weights.len() {
267        return Err(DatasetsError::InvalidFormat(format!(
268            "Cannot sample {} items from {} without replacement",
269            sample_size,
270            weights.len()
271        )));
272    }
273
274    // Check for negative weights
275    for &weight in weights.iter() {
276        if weight < 0.0 {
277            return Err(DatasetsError::InvalidFormat(
278                "All weights must be non-negative".to_string(),
279            ));
280        }
281    }
282
283    let weight_sum: f64 = weights.sum();
284    if weight_sum <= 0.0 {
285        return Err(DatasetsError::InvalidFormat(
286            "Sum of weights must be positive".to_string(),
287        ));
288    }
289
290    let mut rng = match random_seed {
291        Some(seed) => StdRng::seed_from_u64(seed),
292        None => {
293            let mut r = rng();
294            StdRng::seed_from_u64(r.next_u64())
295        }
296    };
297
298    let mut indices = Vec::with_capacity(sample_size);
299    let mut available_weights = weights.clone();
300    let mut available_indices: Vec<usize> = (0..weights.len()).collect();
301
302    for _ in 0..sample_size {
303        let current_sum = available_weights.sum();
304        if current_sum <= 0.0 {
305            break;
306        }
307
308        // Generate random number between 0 and current_sum
309        let random_value = rng.random_range(0.0..current_sum);
310
311        // Find the index corresponding to this random value
312        let mut cumulative_weight = 0.0;
313        let mut selected_idx = 0;
314
315        for (i, &weight) in available_weights.iter().enumerate() {
316            cumulative_weight += weight;
317            if random_value <= cumulative_weight {
318                selected_idx = i;
319                break;
320            }
321        }
322
323        let original_idx = available_indices[selected_idx];
324        indices.push(original_idx);
325
326        if !replace {
327            // Remove the selected item for sampling without replacement
328            available_weights = Array1::from_iter(
329                available_weights
330                    .iter()
331                    .enumerate()
332                    .filter(|(i, _)| *i != selected_idx)
333                    .map(|(_, &w)| w),
334            );
335            available_indices.remove(selected_idx);
336        }
337    }
338
339    Ok(indices)
340}
341
342/// Generate bootstrap samples from indices
343///
344/// This is a convenience function that generates bootstrap samples (sampling with
345/// replacement) which is commonly used for bootstrap confidence intervals and
346/// ensemble methods.
347///
348/// # Arguments
349///
350/// * `n_samples` - Total number of samples in the dataset
351/// * `n_bootstrap_samples` - Number of bootstrap samples to generate
352/// * `random_seed` - Optional random seed for reproducible sampling
353///
354/// # Returns
355///
356/// A vector of bootstrap sample indices
357///
358/// # Examples
359///
360/// ```rust
361/// use scirs2_datasets::utils::bootstrap_sample;
362///
363/// let bootstrap_indices = bootstrap_sample(100, 100, Some(42)).unwrap();
364/// assert_eq!(bootstrap_indices.len(), 100);
365///
366/// // Some indices should appear multiple times (with high probability)
367/// let mut unique_indices = bootstrap_indices.clone();
368/// unique_indices.sort();
369/// unique_indices.dedup();
370/// assert!(unique_indices.len() < bootstrap_indices.len());
371/// ```
372pub fn bootstrap_sample(
373    n_samples: usize,
374    n_bootstrap_samples: usize,
375    random_seed: Option<u64>,
376) -> Result<Vec<usize>> {
377    random_sample(n_samples, n_bootstrap_samples, true, random_seed)
378}
379
380/// Generate multiple bootstrap samples
381///
382/// Creates multiple independent bootstrap samples, useful for ensemble methods
383/// like bagging or for computing bootstrap confidence intervals.
384///
385/// # Arguments
386///
387/// * `n_samples` - Total number of samples in the dataset
388/// * `sample_size` - Size of each bootstrap sample
389/// * `n_bootstrap_rounds` - Number of bootstrap samples to generate
390/// * `random_seed` - Optional random seed for reproducible sampling
391///
392/// # Returns
393///
394/// A vector of bootstrap sample vectors
395///
396/// # Examples
397///
398/// ```rust
399/// use scirs2_datasets::utils::multiple_bootstrap_samples;
400///
401/// let bootstrap_samples = multiple_bootstrap_samples(50, 50, 10, Some(42)).unwrap();
402/// assert_eq!(bootstrap_samples.len(), 10);
403/// assert!(bootstrap_samples.iter().all(|sample| sample.len() == 50));
404/// ```
405pub fn multiple_bootstrap_samples(
406    n_samples: usize,
407    sample_size: usize,
408    n_bootstrap_rounds: usize,
409    random_seed: Option<u64>,
410) -> Result<Vec<Vec<usize>>> {
411    if n_bootstrap_rounds == 0 {
412        return Err(DatasetsError::InvalidFormat(
413            "Number of bootstrap rounds must be > 0".to_string(),
414        ));
415    }
416
417    let mut rng = match random_seed {
418        Some(seed) => StdRng::seed_from_u64(seed),
419        None => {
420            let mut r = rng();
421            StdRng::seed_from_u64(r.next_u64())
422        }
423    };
424
425    let mut bootstrap_samples = Vec::with_capacity(n_bootstrap_rounds);
426
427    for _ in 0..n_bootstrap_rounds {
428        let sample = random_sample(n_samples, sample_size, true, Some(rng.next_u64()))?;
429        bootstrap_samples.push(sample);
430    }
431
432    Ok(bootstrap_samples)
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use std::collections::HashSet;

    #[test]
    fn test_random_sample_without_replacement() {
        let indices = random_sample(10, 5, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 10));

        // All indices should be unique (no replacement)
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 5);
    }

    #[test]
    fn test_random_sample_with_replacement() {
        let indices = random_sample(5, 10, true, Some(42)).unwrap();

        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i < 5));

        // Ten draws from only five values must repeat at least one index (pigeonhole),
        // which is only possible when sampling with replacement.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() < indices.len());
    }

    #[test]
    fn test_random_sample_invalid_params() {
        // Zero samples
        assert!(random_sample(0, 5, false, None).is_err());

        // Zero sample size
        assert!(random_sample(10, 0, false, None).is_err());

        // Too many samples without replacement
        assert!(random_sample(5, 10, false, None).is_err());
    }

    #[test]
    fn test_sampling_is_reproducible_with_seed() {
        assert_eq!(
            random_sample(100, 10, false, Some(7)).unwrap(),
            random_sample(100, 10, false, Some(7)).unwrap()
        );
        assert_eq!(
            random_sample(100, 10, true, Some(7)).unwrap(),
            random_sample(100, 10, true, Some(7)).unwrap()
        );

        let targets = array![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        assert_eq!(
            stratified_sample(&targets, 6, Some(7)).unwrap(),
            stratified_sample(&targets, 6, Some(7)).unwrap()
        );

        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0];
        assert_eq!(
            importance_sample(&weights, 3, false, Some(7)).unwrap(),
            importance_sample(&weights, 3, false, Some(7)).unwrap()
        );

        assert_eq!(
            multiple_bootstrap_samples(10, 8, 3, Some(7)).unwrap(),
            multiple_bootstrap_samples(10, 8, 3, Some(7)).unwrap()
        );
    }

    #[test]
    fn test_stratified_sample() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]; // 2, 3, 3 samples per class
        let indices = stratified_sample(&targets, 6, Some(42)).unwrap();

        assert_eq!(indices.len(), 6);

        // Sampling within each class is without replacement, so indices are distinct
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 6);

        // Count samples per class in the result
        let mut class_counts = HashMap::new();
        for &idx in &indices {
            let class = targets[idx] as i32;
            *class_counts.entry(class).or_insert(0) += 1;
        }

        // Every class is represented: 6 slots over class sizes 2/3/3 yields 2 per class
        assert_eq!(class_counts.len(), 3);
        assert!(class_counts.values().all(|&count| count == 2));
    }

    #[test]
    fn test_stratified_sample_insufficient_samples() {
        let targets = array![0.0, 1.0]; // Only 1 sample per class
        // Requesting 4 samples but only 2 total
        assert!(stratified_sample(&targets, 4, Some(42)).is_err());
    }

    #[test]
    fn test_importance_sample() {
        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0]; // Higher weights at the end
        let indices = importance_sample(&weights, 3, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&i| i < 6));

        // All indices should be unique (no replacement)
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 3);
    }

    #[test]
    fn test_importance_sample_never_selects_zero_weight() {
        let weights = array![0.0, 1.0, 0.0, 2.0];
        let indices = importance_sample(&weights, 50, true, Some(42)).unwrap();

        assert_eq!(indices.len(), 50);
        assert!(indices.iter().all(|&i| i == 1 || i == 3));
    }

    #[test]
    fn test_importance_sample_negative_weights() {
        let weights = array![0.5, -0.1, 0.3]; // Contains negative weight
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_importance_sample_zero_weights() {
        let weights = array![0.0, 0.0, 0.0]; // All zero weights
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_bootstrap_sample() {
        let indices = bootstrap_sample(20, 20, Some(42)).unwrap();

        assert_eq!(indices.len(), 20);
        assert!(indices.iter().all(|&i| i < 20));

        // Should likely have some repeated indices
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() < 20); // Very likely with size 20
    }

    #[test]
    fn test_multiple_bootstrap_samples() {
        let samples = multiple_bootstrap_samples(10, 8, 5, Some(42)).unwrap();

        assert_eq!(samples.len(), 5);
        assert!(samples.iter().all(|sample| sample.len() == 8));
        assert!(samples.iter().all(|sample| sample.iter().all(|&i| i < 10)));

        // Different bootstrap samples should be different
        assert_ne!(samples[0], samples[1]); // Very likely to be different
    }

    #[test]
    fn test_multiple_bootstrap_samples_invalid_params() {
        assert!(multiple_bootstrap_samples(10, 10, 0, None).is_err());
    }
}