Function stratified_sample

Source
pub fn stratified_sample(
    targets: &Array1<f64>,
    sample_size: usize,
    random_seed: Option<u64>,
) -> Result<Vec<usize>>
Expand description

Performs stratified random sampling.

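Preserves, as closely as integer sample counts allow, the class distribution of the original dataset in the sample. (Exact proportions are not always attainable: drawing 6 samples from 8 rows where one class has 2 rows would require 1.5 samples of that class.) This is particularly useful for classification tasks where every class should be represented in the sample in proportion to its frequency.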

§Arguments

  • targets - Target values (class labels) whose distinct values define the strata
  • sample_size - Number of samples to draw
  • random_seed - Optional random seed for reproducible sampling

§Returns

A vector of row indices into `targets` that make up the stratified sample

§Examples

use ndarray::Array1;
use scirs2_datasets::utils::stratified_sample;
use std::collections::HashMap;

let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
let indices = stratified_sample(&targets, 6, Some(42)).unwrap();
assert_eq!(indices.len(), 6);

// Tally how many rows of each class exist in the full dataset and in the sample.
let mut original_counts: HashMap<i32, usize> = HashMap::new();
for &t in targets.iter() {
    *original_counts.entry(t as i32).or_insert(0) += 1;
}
let mut class_counts: HashMap<i32, usize> = HashMap::new();
for &idx in &indices {
    let class = targets[idx] as i32;
    *class_counts.entry(class).or_insert(0) += 1;
}

// Check that the sample maintains class proportions. Integer counts cannot
// match fractional shares exactly (class 0 would need 2/8 * 6 = 1.5 samples),
// so each class must land within one sample of its proportional share.
for (class, &total) in &original_counts {
    let share = total as f64 / targets.len() as f64 * indices.len() as f64;
    let got = class_counts.get(class).copied().unwrap_or(0) as f64;
    assert!(
        (got - share).abs() < 1.0,
        "class {class}: sampled {got}, expected about {share}"
    );
}
Examples found in repository?
examples/sampling_demo.rs (line 92)
9fn main() {
10    println!("=== Sampling and Bootstrapping Demonstration ===\n");
11
12    // Load the Iris dataset for demonstration
13    let iris = load_iris().unwrap();
14    let n_samples = iris.n_samples();
15
16    println!("Original Iris dataset:");
17    println!("- Samples: {}", n_samples);
18    println!("- Features: {}", iris.n_features());
19
20    if let Some(target) = &iris.target {
21        let class_counts = count_classes(target);
22        println!("- Class distribution: {:?}\n", class_counts);
23    }
24
25    // Demonstrate random sampling without replacement
26    println!("=== Random Sampling (without replacement) ===");
27    let sample_size = 30;
28    let random_indices = random_sample(n_samples, sample_size, false, Some(42)).unwrap();
29
30    println!(
31        "Sampled {} indices from {} total samples",
32        sample_size, n_samples
33    );
34    println!(
35        "Sample indices: {:?}",
36        &random_indices[..10.min(random_indices.len())]
37    );
38
39    // Create a subset dataset
40    let sample_data = iris.data.select(ndarray::Axis(0), &random_indices);
41    let sample_target = iris
42        .target
43        .as_ref()
44        .map(|t| t.select(ndarray::Axis(0), &random_indices));
45    let sample_dataset = Dataset::new(sample_data, sample_target)
46        .with_description("Random sample from Iris dataset".to_string());
47
48    println!(
49        "Random sample dataset: {} samples, {} features",
50        sample_dataset.n_samples(),
51        sample_dataset.n_features()
52    );
53
54    if let Some(target) = &sample_dataset.target {
55        let sample_class_counts = count_classes(target);
56        println!("Sample class distribution: {:?}\n", sample_class_counts);
57    }
58
59    // Demonstrate bootstrap sampling (with replacement)
60    println!("=== Bootstrap Sampling (with replacement) ===");
61    let bootstrap_size = 200; // More than original dataset size
62    let bootstrap_indices = random_sample(n_samples, bootstrap_size, true, Some(42)).unwrap();
63
64    println!(
65        "Bootstrap sampled {} indices from {} total samples",
66        bootstrap_size, n_samples
67    );
68    println!(
69        "Bootstrap may have duplicates - first 10 indices: {:?}",
70        &bootstrap_indices[..10]
71    );
72
73    // Count frequency of each index in bootstrap sample
74    let mut index_counts = vec![0; n_samples];
75    for &idx in &bootstrap_indices {
76        index_counts[idx] += 1;
77    }
78    let max_count = *index_counts.iter().max().unwrap();
79    let zero_count = index_counts.iter().filter(|&&count| count == 0).count();
80
81    println!("Bootstrap statistics:");
82    println!("- Maximum frequency of any sample: {}", max_count);
83    println!(
84        "- Number of original samples not selected: {}\n",
85        zero_count
86    );
87
88    // Demonstrate stratified sampling
89    println!("=== Stratified Sampling ===");
90    if let Some(target) = &iris.target {
91        let stratified_size = 30;
92        let stratified_indices = stratified_sample(target, stratified_size, Some(42)).unwrap();
93
94        println!(
95            "Stratified sampled {} indices maintaining class proportions",
96            stratified_size
97        );
98
99        // Create stratified subset
100        let stratified_data = iris.data.select(ndarray::Axis(0), &stratified_indices);
101        let stratified_target = target.select(ndarray::Axis(0), &stratified_indices);
102        let stratified_dataset = Dataset::new(stratified_data, Some(stratified_target))
103            .with_description("Stratified sample from Iris dataset".to_string());
104
105        println!(
106            "Stratified sample dataset: {} samples, {} features",
107            stratified_dataset.n_samples(),
108            stratified_dataset.n_features()
109        );
110
111        let stratified_class_counts = count_classes(&stratified_dataset.target.unwrap());
112        println!(
113            "Stratified sample class distribution: {:?}",
114            stratified_class_counts
115        );
116
117        // Verify proportions are maintained
118        let original_proportions = calculate_proportions(&count_classes(target));
119        let stratified_proportions = calculate_proportions(&stratified_class_counts);
120
121        println!("Class proportion comparison:");
122        for (&class, &original_prop) in &original_proportions {
123            let stratified_prop = stratified_proportions.get(&class).unwrap_or(&0.0);
124            println!(
125                "  Class {}: Original {:.2}%, Stratified {:.2}%",
126                class,
127                original_prop * 100.0,
128                stratified_prop * 100.0
129            );
130        }
131    }
132
133    // Demonstrate practical use case: creating training/validation splits
134    println!("\n=== Practical Example: Multiple Train/Validation Splits ===");
135    for i in 1..=3 {
136        let split_indices = random_sample(n_samples, 100, false, Some(42 + i)).unwrap();
137        let (train_indices, val_indices) = split_indices.split_at(80);
138
139        println!(
140            "Split {}: {} training samples, {} validation samples",
141            i,
142            train_indices.len(),
143            val_indices.len()
144        );
145    }
146
147    println!("\n=== Sampling Demo Complete ===");
148}