pub fn generate_synthetic_samples(
data: &Array2<f64>,
targets: &Array1<f64>,
target_class: f64,
n_synthetic: usize,
k_neighbors: usize,
random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>Expand description
Generates synthetic samples using SMOTE-like interpolation
Creates synthetic samples by interpolating between existing samples of the requested class (`target_class`). This is useful for oversampling minority classes without simple duplication.
§Arguments
- `data` - Feature matrix (n_samples, n_features)
- `targets` - Target values for each sample
- `target_class` - The class to generate synthetic samples for
- `n_synthetic` - Number of synthetic samples to generate
- `k_neighbors` - Number of nearest neighbors to consider for interpolation
- `random_seed` - Optional random seed for reproducible generation
§Returns
A tuple containing the synthetic (data, targets) arrays
§Examples
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_datasets::utils::generate_synthetic_samples;
let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
let (synthetic_data, synthetic_targets) = generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();
assert_eq!(synthetic_data.nrows(), 2);
assert_eq!(synthetic_targets.len(), 2);
Examples found in repository:
examples/balancing_demo.rs (line 58)
13fn main() {
14 println!("=== Data Balancing Utilities Demonstration ===\n");
15
16 // Create an artificially imbalanced dataset for demonstration
17 let data = Array2::from_shape_vec(
18 (10, 2),
19 vec![
20 // Class 0 (minority): 2 samples
21 1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
22 5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
23 // Class 2 (moderate): 2 samples
24 10.0, 10.0, 10.1, 9.9,
25 ],
26 )
27 .unwrap();
28
29 let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
30
31 println!("Original imbalanced dataset:");
32 print_class_distribution(&targets);
33 println!("Total samples: {}\n", data.nrows());
34
35 // Demonstrate random oversampling
36 println!("=== Random Oversampling =======================");
37 let (oversampleddata, oversampled_targets) =
38 random_oversample(&data, &targets, Some(42)).unwrap();
39
40 println!("After random oversampling:");
41 print_class_distribution(&oversampled_targets);
42 println!("Total samples: {}\n", oversampleddata.nrows());
43
44 // Demonstrate random undersampling
45 println!("=== Random Undersampling ======================");
46 let (undersampleddata, undersampled_targets) =
47 random_undersample(&data, &targets, Some(42)).unwrap();
48
49 println!("After random undersampling:");
50 print_class_distribution(&undersampled_targets);
51 println!("Total samples: {}\n", undersampleddata.nrows());
52
53 // Demonstrate SMOTE-like synthetic sample generation
54 println!("=== Synthetic Sample Generation (SMOTE-like) ==");
55
56 // Generate 4 synthetic samples for class 0 (minority class)
57 let (syntheticdata, synthetic_targets) =
58 generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
59
60 println!(
61 "Generated {} synthetic samples for class 0",
62 syntheticdata.nrows()
63 );
64 println!("Synthetic samples (first 3 features of each):");
65 for i in 0..syntheticdata.nrows() {
66 println!(
67 " Sample {}: [{:.3}, {:.3}] -> class {}",
68 i,
69 syntheticdata[[i, 0]],
70 syntheticdata[[i, 1]],
71 synthetic_targets[i]
72 );
73 }
74 println!();
75
76 // Demonstrate unified balancing strategies
77 println!("=== Unified Balancing Strategies ==============");
78
79 // Strategy 1: Random Oversampling
80 let (balanced_over, targets_over) = create_balanced_dataset(
81 &data,
82 &targets,
83 BalancingStrategy::RandomOversample,
84 Some(42),
85 )
86 .unwrap();
87
88 println!("Strategy: Random Oversampling");
89 print_class_distribution(&targets_over);
90 println!("Total samples: {}", balanced_over.nrows());
91
92 // Strategy 2: Random Undersampling
93 let (balanced_under, targets_under) = create_balanced_dataset(
94 &data,
95 &targets,
96 BalancingStrategy::RandomUndersample,
97 Some(42),
98 )
99 .unwrap();
100
101 println!("\nStrategy: Random Undersampling");
102 print_class_distribution(&targets_under);
103 println!("Total samples: {}", balanced_under.nrows());
104
105 // Strategy 3: SMOTE with k=1 neighbors
106 let (balanced_smote, targets_smote) = create_balanced_dataset(
107 &data,
108 &targets,
109 BalancingStrategy::SMOTE { k_neighbors: 1 },
110 Some(42),
111 )
112 .unwrap();
113
114 println!("\nStrategy: SMOTE (k_neighbors=1)");
115 print_class_distribution(&targets_smote);
116 println!("Total samples: {}", balanced_smote.nrows());
117
118 // Demonstrate with real-world dataset
119 println!("\n=== Real-world Example: Iris Dataset ==========");
120
121 let iris = load_iris().unwrap();
122 if let Some(iris_targets) = &iris.target {
123 println!("Original Iris dataset:");
124 print_class_distribution(iris_targets);
125
126 // Apply oversampling to iris (it's already balanced, but for demonstration)
127 let (iris_balanced, iris_balanced_targets) =
128 random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
129
130 println!("\nIris after oversampling (should remain the same):");
131 print_class_distribution(&iris_balanced_targets);
132 println!("Total samples: {}", iris_balanced.nrows());
133
134 // Create artificial imbalance by removing some samples
135 let indices_to_keep: Vec<usize> = (0..150)
136 .filter(|&i| {
137 let class = iris_targets[i].round() as i64;
138 // Keep all of class 0, 30 of class 1, 10 of class 2
139 match class {
140 0 => true, // Keep all 50
141 1 => i < 80, // Keep first 30 (indices 50-79)
142 2 => i < 110, // Keep first 10 (indices 100-109)
143 _ => false,
144 }
145 })
146 .collect();
147
148 let imbalanceddata = iris
149 .data
150 .select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
151 let imbalanced_targets =
152 iris_targets.select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
153
154 println!("\nArtificially imbalanced Iris:");
155 print_class_distribution(&imbalanced_targets);
156
157 // Balance it using SMOTE
158 let (rebalanceddata, rebalanced_targets) = create_balanced_dataset(
159 &imbalanceddata,
160 &imbalanced_targets,
161 BalancingStrategy::SMOTE { k_neighbors: 3 },
162 Some(42),
163 )
164 .unwrap();
165
166 println!("\nAfter SMOTE rebalancing:");
167 print_class_distribution(&rebalanced_targets);
168 println!("Total samples: {}", rebalanceddata.nrows());
169 }
170
171 println!("\n=== Performance Comparison ====================");
172
173 // Show the tradeoffs between different strategies
174 println!("Strategy Comparison Summary:");
175 println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
176 println!("│ Strategy │ Final Size │ Characteristics │");
177 println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
178 println!(
179 "│ Random Oversample │ {} samples │ Increases data size, duplicates │",
180 balanced_over.nrows()
181 );
182 println!(
183 "│ Random Undersample │ {} samples │ Reduces data size, loses info │",
184 balanced_under.nrows()
185 );
186 println!(
187 "│ SMOTE │ {} samples │ Increases size, synthetic data │",
188 balanced_smote.nrows()
189 );
190 println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
191
192 println!("\n=== Balancing Demo Complete ====================");
193}