pub fn random_undersample(
data: &Array2<f64>,
targets: &Array1<f64>,
random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Performs random undersampling to balance class distribution
Randomly removes samples from majority classes to match the minority class size. This reduces the overall dataset size but maintains balance.
§Arguments
- `data` - Feature matrix (n_samples, n_features)
- `targets` - Target values for each sample
- `random_seed` - Optional random seed for reproducible sampling
§Returns
A tuple containing the undersampled (data, targets) arrays
§Examples
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_datasets::utils::random_undersample;

// Six samples with two features each, filled row by row with 1.0 through 12.0.
let features: Vec<f64> = (1..=12_i32).map(f64::from).collect();
let data = Array2::from_shape_vec((6, 2), features).unwrap();

// Imbalanced labels: class 0 appears twice, class 1 four times.
let labels = vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
let targets = Array1::from(labels);

// A fixed seed makes the sampling reproducible.
let seed = Some(42);
let (balanced_data, balanced_targets) = random_undersample(&data, &targets, seed).unwrap();
// Now both classes have 2 samples each

Examples found in repository?
examples/balancing_demo.rs (line 47)
13fn main() {
14 println!("=== Data Balancing Utilities Demonstration ===\n");
15
16 // Create an artificially imbalanced dataset for demonstration
17 let data = Array2::from_shape_vec(
18 (10, 2),
19 vec![
20 // Class 0 (minority): 2 samples
21 1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
22 5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
23 // Class 2 (moderate): 2 samples
24 10.0, 10.0, 10.1, 9.9,
25 ],
26 )
27 .unwrap();
28
29 let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
30
31 println!("Original imbalanced dataset:");
32 print_class_distribution(&targets);
33 println!("Total samples: {}\n", data.nrows());
34
35 // Demonstrate random oversampling
36 println!("=== Random Oversampling =======================");
37 let (oversampleddata, oversampled_targets) =
38 random_oversample(&data, &targets, Some(42)).unwrap();
39
40 println!("After random oversampling:");
41 print_class_distribution(&oversampled_targets);
42 println!("Total samples: {}\n", oversampleddata.nrows());
43
44 // Demonstrate random undersampling
45 println!("=== Random Undersampling ======================");
46 let (undersampleddata, undersampled_targets) =
47 random_undersample(&data, &targets, Some(42)).unwrap();
48
49 println!("After random undersampling:");
50 print_class_distribution(&undersampled_targets);
51 println!("Total samples: {}\n", undersampleddata.nrows());
52
53 // Demonstrate SMOTE-like synthetic sample generation
54 println!("=== Synthetic Sample Generation (SMOTE-like) ==");
55
56 // Generate 4 synthetic samples for class 0 (minority class)
57 let (syntheticdata, synthetic_targets) =
58 generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
59
60 println!(
61 "Generated {} synthetic samples for class 0",
62 syntheticdata.nrows()
63 );
64 println!("Synthetic samples (first 3 features of each):");
65 for i in 0..syntheticdata.nrows() {
66 println!(
67 " Sample {}: [{:.3}, {:.3}] -> class {}",
68 i,
69 syntheticdata[[i, 0]],
70 syntheticdata[[i, 1]],
71 synthetic_targets[i]
72 );
73 }
74 println!();
75
76 // Demonstrate unified balancing strategies
77 println!("=== Unified Balancing Strategies ==============");
78
79 // Strategy 1: Random Oversampling
80 let (balanced_over, targets_over) = create_balanced_dataset(
81 &data,
82 &targets,
83 BalancingStrategy::RandomOversample,
84 Some(42),
85 )
86 .unwrap();
87
88 println!("Strategy: Random Oversampling");
89 print_class_distribution(&targets_over);
90 println!("Total samples: {}", balanced_over.nrows());
91
92 // Strategy 2: Random Undersampling
93 let (balanced_under, targets_under) = create_balanced_dataset(
94 &data,
95 &targets,
96 BalancingStrategy::RandomUndersample,
97 Some(42),
98 )
99 .unwrap();
100
101 println!("\nStrategy: Random Undersampling");
102 print_class_distribution(&targets_under);
103 println!("Total samples: {}", balanced_under.nrows());
104
105 // Strategy 3: SMOTE with k=1 neighbors
106 let (balanced_smote, targets_smote) = create_balanced_dataset(
107 &data,
108 &targets,
109 BalancingStrategy::SMOTE { k_neighbors: 1 },
110 Some(42),
111 )
112 .unwrap();
113
114 println!("\nStrategy: SMOTE (k_neighbors=1)");
115 print_class_distribution(&targets_smote);
116 println!("Total samples: {}", balanced_smote.nrows());
117
118 // Demonstrate with real-world dataset
119 println!("\n=== Real-world Example: Iris Dataset ==========");
120
121 let iris = load_iris().unwrap();
122 if let Some(iris_targets) = &iris.target {
123 println!("Original Iris dataset:");
124 print_class_distribution(iris_targets);
125
126 // Apply oversampling to iris (it's already balanced, but for demonstration)
127 let (iris_balanced, iris_balanced_targets) =
128 random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
129
130 println!("\nIris after oversampling (should remain the same):");
131 print_class_distribution(&iris_balanced_targets);
132 println!("Total samples: {}", iris_balanced.nrows());
133
134 // Create artificial imbalance by removing some samples
135 let indices_to_keep: Vec<usize> = (0..150)
136 .filter(|&i| {
137 let class = iris_targets[i].round() as i64;
138 // Keep all of class 0, 30 of class 1, 10 of class 2
139 match class {
140 0 => true, // Keep all 50
141 1 => i < 80, // Keep first 30 (indices 50-79)
142 2 => i < 110, // Keep first 10 (indices 100-109)
143 _ => false,
144 }
145 })
146 .collect();
147
148 let imbalanceddata = iris
149 .data
150 .select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
151 let imbalanced_targets =
152 iris_targets.select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
153
154 println!("\nArtificially imbalanced Iris:");
155 print_class_distribution(&imbalanced_targets);
156
157 // Balance it using SMOTE
158 let (rebalanceddata, rebalanced_targets) = create_balanced_dataset(
159 &imbalanceddata,
160 &imbalanced_targets,
161 BalancingStrategy::SMOTE { k_neighbors: 3 },
162 Some(42),
163 )
164 .unwrap();
165
166 println!("\nAfter SMOTE rebalancing:");
167 print_class_distribution(&rebalanced_targets);
168 println!("Total samples: {}", rebalanceddata.nrows());
169 }
170
171 println!("\n=== Performance Comparison ====================");
172
173 // Show the tradeoffs between different strategies
174 println!("Strategy Comparison Summary:");
175 println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
176 println!("│ Strategy │ Final Size │ Characteristics │");
177 println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
178 println!(
179 "│ Random Oversample │ {} samples │ Increases data size, duplicates │",
180 balanced_over.nrows()
181 );
182 println!(
183 "│ Random Undersample │ {} samples │ Reduces data size, loses info │",
184 balanced_under.nrows()
185 );
186 println!(
187 "│ SMOTE │ {} samples │ Increases size, synthetic data │",
188 balanced_smote.nrows()
189 );
190 println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
191
192 println!("\n=== Balancing Demo Complete ====================");
193}