pub fn stratified_sample(
targets: &Array1<f64>,
sample_size: usize,
random_seed: Option<u64>,
) -> Result<Vec<usize>>
Expand description
Performs stratified random sampling.
Maintains the same class distribution in the sample as in the original dataset. This is particularly useful for classification tasks where you want to ensure that all classes are represented proportionally in your sample.
§Arguments
- `targets` - Target values for stratification
- `sample_size` - Number of samples to draw
- `random_seed` - Optional random seed for reproducible sampling
§Returns
A vector of indices representing the stratified sample
§Examples
use ndarray::Array1;
use scirs2_datasets::utils::stratified_sample;

// Eight labels spread over three classes: two of class 0, three each of classes 1 and 2.
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

// Draw six indices, passing an explicit seed.
let indices = stratified_sample(&targets, 6, Some(42)).unwrap();
assert_eq!(indices.len(), 6);

// Tally how many of the sampled indices fall into each class
// (labels are truncated to `i32` so they can be used as map keys).
let class_counts = indices
    .iter()
    .fold(std::collections::HashMap::new(), |mut tally, &idx| {
        *tally.entry(targets[idx] as i32).or_insert(0) += 1;
        tally
    });
Examples found in repository?
examples/sampling_demo.rs (line 92)
9fn main() {
10 println!("=== Sampling and Bootstrapping Demonstration ===\n");
11
12 // Load the Iris dataset for demonstration
13 let iris = load_iris().unwrap();
14 let n_samples = iris.n_samples();
15
16 println!("Original Iris dataset:");
17 println!("- Samples: {}", n_samples);
18 println!("- Features: {}", iris.n_features());
19
20 if let Some(target) = &iris.target {
21 let class_counts = count_classes(target);
22 println!("- Class distribution: {:?}\n", class_counts);
23 }
24
25 // Demonstrate random sampling without replacement
26 println!("=== Random Sampling (without replacement) ===");
27 let sample_size = 30;
28 let random_indices = random_sample(n_samples, sample_size, false, Some(42)).unwrap();
29
30 println!(
31 "Sampled {} indices from {} total samples",
32 sample_size, n_samples
33 );
34 println!(
35 "Sample indices: {:?}",
36 &random_indices[..10.min(random_indices.len())]
37 );
38
39 // Create a subset dataset
40 let sample_data = iris.data.select(ndarray::Axis(0), &random_indices);
41 let sample_target = iris
42 .target
43 .as_ref()
44 .map(|t| t.select(ndarray::Axis(0), &random_indices));
45 let sample_dataset = Dataset::new(sample_data, sample_target)
46 .with_description("Random sample from Iris dataset".to_string());
47
48 println!(
49 "Random sample dataset: {} samples, {} features",
50 sample_dataset.n_samples(),
51 sample_dataset.n_features()
52 );
53
54 if let Some(target) = &sample_dataset.target {
55 let sample_class_counts = count_classes(target);
56 println!("Sample class distribution: {:?}\n", sample_class_counts);
57 }
58
59 // Demonstrate bootstrap sampling (with replacement)
60 println!("=== Bootstrap Sampling (with replacement) ===");
61 let bootstrap_size = 200; // More than original dataset size
62 let bootstrap_indices = random_sample(n_samples, bootstrap_size, true, Some(42)).unwrap();
63
64 println!(
65 "Bootstrap sampled {} indices from {} total samples",
66 bootstrap_size, n_samples
67 );
68 println!(
69 "Bootstrap may have duplicates - first 10 indices: {:?}",
70 &bootstrap_indices[..10]
71 );
72
73 // Count frequency of each index in bootstrap sample
74 let mut index_counts = vec![0; n_samples];
75 for &idx in &bootstrap_indices {
76 index_counts[idx] += 1;
77 }
78 let max_count = *index_counts.iter().max().unwrap();
79 let zero_count = index_counts.iter().filter(|&&count| count == 0).count();
80
81 println!("Bootstrap statistics:");
82 println!("- Maximum frequency of any sample: {}", max_count);
83 println!(
84 "- Number of original samples not selected: {}\n",
85 zero_count
86 );
87
88 // Demonstrate stratified sampling
89 println!("=== Stratified Sampling ===");
90 if let Some(target) = &iris.target {
91 let stratified_size = 30;
92 let stratified_indices = stratified_sample(target, stratified_size, Some(42)).unwrap();
93
94 println!(
95 "Stratified sampled {} indices maintaining class proportions",
96 stratified_size
97 );
98
99 // Create stratified subset
100 let stratified_data = iris.data.select(ndarray::Axis(0), &stratified_indices);
101 let stratified_target = target.select(ndarray::Axis(0), &stratified_indices);
102 let stratified_dataset = Dataset::new(stratified_data, Some(stratified_target))
103 .with_description("Stratified sample from Iris dataset".to_string());
104
105 println!(
106 "Stratified sample dataset: {} samples, {} features",
107 stratified_dataset.n_samples(),
108 stratified_dataset.n_features()
109 );
110
111 let stratified_class_counts = count_classes(&stratified_dataset.target.unwrap());
112 println!(
113 "Stratified sample class distribution: {:?}",
114 stratified_class_counts
115 );
116
117 // Verify proportions are maintained
118 let original_proportions = calculate_proportions(&count_classes(target));
119 let stratified_proportions = calculate_proportions(&stratified_class_counts);
120
121 println!("Class proportion comparison:");
122 for (&class, &original_prop) in &original_proportions {
123 let stratified_prop = stratified_proportions.get(&class).unwrap_or(&0.0);
124 println!(
125 " Class {}: Original {:.2}%, Stratified {:.2}%",
126 class,
127 original_prop * 100.0,
128 stratified_prop * 100.0
129 );
130 }
131 }
132
133 // Demonstrate practical use case: creating training/validation splits
134 println!("\n=== Practical Example: Multiple Train/Validation Splits ===");
135 for i in 1..=3 {
136 let split_indices = random_sample(n_samples, 100, false, Some(42 + i)).unwrap();
137 let (train_indices, val_indices) = split_indices.split_at(80);
138
139 println!(
140 "Split {}: {} training samples, {} validation samples",
141 i,
142 train_indices.len(),
143 val_indices.len()
144 );
145 }
146
147 println!("\n=== Sampling Demo Complete ===");
148}