pub fn random_oversample(
data: &Array2<f64>,
targets: &Array1<f64>,
random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description
Performs random oversampling to balance class distribution
Duplicates samples from minority classes to match the majority class size. This is useful for handling imbalanced datasets in classification problems.
§Arguments
- `data` - Feature matrix (n_samples, n_features)
- `targets` - Target values for each sample
- `random_seed` - Optional random seed for reproducible sampling
§Returns
A tuple containing the resampled (data, targets) arrays
§Examples
use ndarray::{Array1, Array2};
use scirs2_datasets::utils::random_oversample;
let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42)).unwrap();
// Now both classes have 4 samples each
Examples found in repository?
examples/balancing_demo.rs (line 37)
12fn main() {
13 println!("=== Data Balancing Utilities Demonstration ===\n");
14
15 // Create an artificially imbalanced dataset for demonstration
16 let data = Array2::from_shape_vec(
17 (10, 2),
18 vec![
19 // Class 0 (minority): 2 samples
20 1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
21 5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
22 // Class 2 (moderate): 2 samples
23 10.0, 10.0, 10.1, 9.9,
24 ],
25 )
26 .unwrap();
27
28 let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
29
30 println!("Original imbalanced dataset:");
31 print_class_distribution(&targets);
32 println!("Total samples: {}\n", data.nrows());
33
34 // Demonstrate random oversampling
35 println!("=== Random Oversampling =======================");
36 let (oversampled_data, oversampled_targets) =
37 random_oversample(&data, &targets, Some(42)).unwrap();
38
39 println!("After random oversampling:");
40 print_class_distribution(&oversampled_targets);
41 println!("Total samples: {}\n", oversampled_data.nrows());
42
43 // Demonstrate random undersampling
44 println!("=== Random Undersampling ======================");
45 let (undersampled_data, undersampled_targets) =
46 random_undersample(&data, &targets, Some(42)).unwrap();
47
48 println!("After random undersampling:");
49 print_class_distribution(&undersampled_targets);
50 println!("Total samples: {}\n", undersampled_data.nrows());
51
52 // Demonstrate SMOTE-like synthetic sample generation
53 println!("=== Synthetic Sample Generation (SMOTE-like) ==");
54
55 // Generate 4 synthetic samples for class 0 (minority class)
56 let (synthetic_data, synthetic_targets) =
57 generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
58
59 println!(
60 "Generated {} synthetic samples for class 0",
61 synthetic_data.nrows()
62 );
63 println!("Synthetic samples (first 3 features of each):");
64 for i in 0..synthetic_data.nrows() {
65 println!(
66 " Sample {}: [{:.3}, {:.3}] -> class {}",
67 i,
68 synthetic_data[[i, 0]],
69 synthetic_data[[i, 1]],
70 synthetic_targets[i]
71 );
72 }
73 println!();
74
75 // Demonstrate unified balancing strategies
76 println!("=== Unified Balancing Strategies ==============");
77
78 // Strategy 1: Random Oversampling
79 let (balanced_over, targets_over) = create_balanced_dataset(
80 &data,
81 &targets,
82 BalancingStrategy::RandomOversample,
83 Some(42),
84 )
85 .unwrap();
86
87 println!("Strategy: Random Oversampling");
88 print_class_distribution(&targets_over);
89 println!("Total samples: {}", balanced_over.nrows());
90
91 // Strategy 2: Random Undersampling
92 let (balanced_under, targets_under) = create_balanced_dataset(
93 &data,
94 &targets,
95 BalancingStrategy::RandomUndersample,
96 Some(42),
97 )
98 .unwrap();
99
100 println!("\nStrategy: Random Undersampling");
101 print_class_distribution(&targets_under);
102 println!("Total samples: {}", balanced_under.nrows());
103
104 // Strategy 3: SMOTE with k=1 neighbors
105 let (balanced_smote, targets_smote) = create_balanced_dataset(
106 &data,
107 &targets,
108 BalancingStrategy::SMOTE { k_neighbors: 1 },
109 Some(42),
110 )
111 .unwrap();
112
113 println!("\nStrategy: SMOTE (k_neighbors=1)");
114 print_class_distribution(&targets_smote);
115 println!("Total samples: {}", balanced_smote.nrows());
116
117 // Demonstrate with real-world dataset
118 println!("\n=== Real-world Example: Iris Dataset ==========");
119
120 let iris = load_iris().unwrap();
121 if let Some(iris_targets) = &iris.target {
122 println!("Original Iris dataset:");
123 print_class_distribution(iris_targets);
124
125 // Apply oversampling to iris (it's already balanced, but for demonstration)
126 let (iris_balanced, iris_balanced_targets) =
127 random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
128
129 println!("\nIris after oversampling (should remain the same):");
130 print_class_distribution(&iris_balanced_targets);
131 println!("Total samples: {}", iris_balanced.nrows());
132
133 // Create artificial imbalance by removing some samples
134 let indices_to_keep: Vec<usize> = (0..150)
135 .filter(|&i| {
136 let class = iris_targets[i].round() as i64;
137 // Keep all of class 0, 30 of class 1, 10 of class 2
138 match class {
139 0 => true, // Keep all 50
140 1 => i < 80, // Keep first 30 (indices 50-79)
141 2 => i < 110, // Keep first 10 (indices 100-109)
142 _ => false,
143 }
144 })
145 .collect();
146
147 let imbalanced_data = iris.data.select(ndarray::Axis(0), &indices_to_keep);
148 let imbalanced_targets = iris_targets.select(ndarray::Axis(0), &indices_to_keep);
149
150 println!("\nArtificially imbalanced Iris:");
151 print_class_distribution(&imbalanced_targets);
152
153 // Balance it using SMOTE
154 let (rebalanced_data, rebalanced_targets) = create_balanced_dataset(
155 &imbalanced_data,
156 &imbalanced_targets,
157 BalancingStrategy::SMOTE { k_neighbors: 3 },
158 Some(42),
159 )
160 .unwrap();
161
162 println!("\nAfter SMOTE rebalancing:");
163 print_class_distribution(&rebalanced_targets);
164 println!("Total samples: {}", rebalanced_data.nrows());
165 }
166
167 println!("\n=== Performance Comparison ====================");
168
169 // Show the tradeoffs between different strategies
170 println!("Strategy Comparison Summary:");
171 println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
172 println!("│ Strategy │ Final Size │ Characteristics │");
173 println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
174 println!(
175 "│ Random Oversample │ {} samples │ Increases data size, duplicates │",
176 balanced_over.nrows()
177 );
178 println!(
179 "│ Random Undersample │ {} samples │ Reduces data size, loses info │",
180 balanced_under.nrows()
181 );
182 println!(
183 "│ SMOTE │ {} samples │ Increases size, synthetic data │",
184 balanced_smote.nrows()
185 );
186 println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
187
188 println!("\n=== Balancing Demo Complete ====================");
189}