sampling_demo/
sampling_demo.rs

1//! Sampling and bootstrapping utilities demonstration
2//!
//! This example demonstrates the random sampling (with and without replacement,
//! i.e. bootstrapping) and stratified sampling utilities provided by scirs2-datasets.
5
6use scirs2_core::ndarray::Array1;
7use scirs2_datasets::{load_iris, random_sample, stratified_sample, Dataset};
8
9#[allow(dead_code)]
10fn main() {
11    println!("=== Sampling and Bootstrapping Demonstration ===\n");
12
13    // Load the Iris dataset for demonstration
14    let iris = load_iris().unwrap();
15    let n_samples = iris.n_samples();
16
17    println!("Original Iris dataset:");
18    println!("- Samples: {n_samples}");
19    println!("- Features: {}", iris.n_features());
20
21    if let Some(target) = &iris.target {
22        let class_counts = count_classes(target);
23        println!("- Class distribution: {class_counts:?}\n");
24    }
25
26    // Demonstrate random sampling without replacement
27    println!("=== Random Sampling (without replacement) ===");
28    let samplesize = 30;
29    let random_indices = random_sample(n_samples, samplesize, false, Some(42)).unwrap();
30
31    println!("Sampled {samplesize} indices from {n_samples} total samples");
32    println!(
33        "Sample indices: {:?}",
34        &random_indices[..10.min(random_indices.len())]
35    );
36
37    // Create a subset dataset
38    let sampledata = iris
39        .data
40        .select(scirs2_core::ndarray::Axis(0), &random_indices);
41    let sample_target = iris
42        .target
43        .as_ref()
44        .map(|t| t.select(scirs2_core::ndarray::Axis(0), &random_indices));
45    let sampledataset = Dataset::new(sampledata, sample_target)
46        .with_description("Random sample from Iris dataset".to_string());
47
48    println!(
49        "Random sample dataset: {} samples, {} features",
50        sampledataset.n_samples(),
51        sampledataset.n_features()
52    );
53
54    if let Some(target) = &sampledataset.target {
55        let sample_class_counts = count_classes(target);
56        println!("Sample class distribution: {sample_class_counts:?}\n");
57    }
58
59    // Demonstrate bootstrap sampling (with replacement)
60    println!("=== Bootstrap Sampling (with replacement) ===");
61    let bootstrapsize = 200; // More than original dataset size
62    let bootstrap_indices = random_sample(n_samples, bootstrapsize, true, Some(42)).unwrap();
63
64    println!("Bootstrap sampled {bootstrapsize} indices from {n_samples} total samples");
65    println!(
66        "Bootstrap may have duplicates - first 10 indices: {:?}",
67        &bootstrap_indices[..10]
68    );
69
70    // Count frequency of each index in bootstrap sample
71    let mut index_counts = vec![0; n_samples];
72    for &idx in &bootstrap_indices {
73        index_counts[idx] += 1;
74    }
75    let max_count = *index_counts.iter().max().unwrap();
76    let zero_count = index_counts.iter().filter(|&&count| count == 0).count();
77
78    println!("Bootstrap statistics:");
79    println!("- Maximum frequency of any sample: {max_count}");
80    println!("- Number of original samples not selected: {zero_count}\n");
81
82    // Demonstrate stratified sampling
83    println!("=== Stratified Sampling ===");
84    if let Some(target) = &iris.target {
85        let stratifiedsize = 30;
86        let stratified_indices = stratified_sample(target, stratifiedsize, Some(42)).unwrap();
87
88        println!("Stratified sampled {stratifiedsize} indices maintaining class proportions");
89
90        // Create stratified subset
91        let stratifieddata = iris
92            .data
93            .select(scirs2_core::ndarray::Axis(0), &stratified_indices);
94        let stratified_target = target.select(scirs2_core::ndarray::Axis(0), &stratified_indices);
95        let stratifieddataset = Dataset::new(stratifieddata, Some(stratified_target))
96            .with_description("Stratified sample from Iris dataset".to_string());
97
98        println!(
99            "Stratified sample dataset: {} samples, {} features",
100            stratifieddataset.n_samples(),
101            stratifieddataset.n_features()
102        );
103
104        let stratified_class_counts = count_classes(&stratifieddataset.target.unwrap());
105        println!("Stratified sample class distribution: {stratified_class_counts:?}");
106
107        // Verify proportions are maintained
108        let original_proportions = calculate_proportions(&count_classes(target));
109        let stratified_proportions = calculate_proportions(&stratified_class_counts);
110
111        println!("Class proportion comparison:");
112        for (&class, &original_prop) in &original_proportions {
113            let stratified_prop = stratified_proportions.get(&class).unwrap_or(&0.0);
114            println!(
115                "  Class {}: Original {:.2}%, Stratified {:.2}%",
116                class,
117                original_prop * 100.0,
118                stratified_prop * 100.0
119            );
120        }
121    }
122
123    // Demonstrate practical use case: creating training/validation splits
124    println!("\n=== Practical Example: Multiple Train/Validation Splits ===");
125    for i in 1..=3 {
126        let split_indices = random_sample(n_samples, 100, false, Some(42 + i)).unwrap();
127        let (train_indices, val_indices) = split_indices.split_at(80);
128
129        println!(
130            "Split {}: {} training samples, {} validation samples",
131            i,
132            train_indices.len(),
133            val_indices.len()
134        );
135    }
136
137    println!("\n=== Sampling Demo Complete ===");
138}
139
140/// Count the number of samples in each class
141#[allow(dead_code)]
142fn count_classes(targets: &Array1<f64>) -> std::collections::HashMap<i64, usize> {
143    let mut counts = std::collections::HashMap::new();
144    for &target in targets.iter() {
145        let class = target.round() as i64;
146        *counts.entry(class).or_insert(0) += 1;
147    }
148    counts
149}
150
/// Convert per-class counts into fractions of the overall total.
///
/// Every class in `counts` maps to `count / total`, where `total` is the sum
/// of all counts. An empty input yields an empty map.
#[allow(dead_code)]
fn calculate_proportions(
    counts: &std::collections::HashMap<i64, usize>,
) -> std::collections::HashMap<i64, f64> {
    let grand_total = counts.values().copied().sum::<usize>() as f64;
    let mut proportions = std::collections::HashMap::with_capacity(counts.len());
    for (label, count) in counts {
        proportions.insert(*label, *count as f64 / grand_total);
    }
    proportions
}