pub fn generate_synthetic_samples(
data: &Array2<f64>,
targets: &Array1<f64>,
target_class: f64,
n_synthetic: usize,
k_neighbors: usize,
random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description
Generates synthetic samples using SMOTE-like interpolation
Creates synthetic samples for the requested target class by interpolating between existing samples of that class. This is useful for oversampling minority classes without simple duplication.
§Arguments
- `data` - Feature matrix (n_samples, n_features)
- `targets` - Target values for each sample
- `target_class` - The class to generate synthetic samples for
- `n_synthetic` - Number of synthetic samples to generate
- `k_neighbors` - Number of nearest neighbors to consider for interpolation
- `random_seed` - Optional random seed for reproducible generation
§Returns
A tuple containing the synthetic (data, targets) arrays
§Examples
use ndarray::{Array1, Array2};
use scirs2_datasets::utils::generate_synthetic_samples;
let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
let (synthetic_data, synthetic_targets) = generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();
assert_eq!(synthetic_data.nrows(), 2);
assert_eq!(synthetic_targets.len(), 2);
Examples found in repository?
examples/balancing_demo.rs (line 57)
12fn main() {
13 println!("=== Data Balancing Utilities Demonstration ===\n");
14
15 // Create an artificially imbalanced dataset for demonstration
16 let data = Array2::from_shape_vec(
17 (10, 2),
18 vec![
19 // Class 0 (minority): 2 samples
20 1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
21 5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
22 // Class 2 (moderate): 2 samples
23 10.0, 10.0, 10.1, 9.9,
24 ],
25 )
26 .unwrap();
27
28 let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
29
30 println!("Original imbalanced dataset:");
31 print_class_distribution(&targets);
32 println!("Total samples: {}\n", data.nrows());
33
34 // Demonstrate random oversampling
35 println!("=== Random Oversampling =======================");
36 let (oversampled_data, oversampled_targets) =
37 random_oversample(&data, &targets, Some(42)).unwrap();
38
39 println!("After random oversampling:");
40 print_class_distribution(&oversampled_targets);
41 println!("Total samples: {}\n", oversampled_data.nrows());
42
43 // Demonstrate random undersampling
44 println!("=== Random Undersampling ======================");
45 let (undersampled_data, undersampled_targets) =
46 random_undersample(&data, &targets, Some(42)).unwrap();
47
48 println!("After random undersampling:");
49 print_class_distribution(&undersampled_targets);
50 println!("Total samples: {}\n", undersampled_data.nrows());
51
52 // Demonstrate SMOTE-like synthetic sample generation
53 println!("=== Synthetic Sample Generation (SMOTE-like) ==");
54
55 // Generate 4 synthetic samples for class 0 (minority class)
56 let (synthetic_data, synthetic_targets) =
57 generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
58
59 println!(
60 "Generated {} synthetic samples for class 0",
61 synthetic_data.nrows()
62 );
63 println!("Synthetic samples (first 3 features of each):");
64 for i in 0..synthetic_data.nrows() {
65 println!(
66 " Sample {}: [{:.3}, {:.3}] -> class {}",
67 i,
68 synthetic_data[[i, 0]],
69 synthetic_data[[i, 1]],
70 synthetic_targets[i]
71 );
72 }
73 println!();
74
75 // Demonstrate unified balancing strategies
76 println!("=== Unified Balancing Strategies ==============");
77
78 // Strategy 1: Random Oversampling
79 let (balanced_over, targets_over) = create_balanced_dataset(
80 &data,
81 &targets,
82 BalancingStrategy::RandomOversample,
83 Some(42),
84 )
85 .unwrap();
86
87 println!("Strategy: Random Oversampling");
88 print_class_distribution(&targets_over);
89 println!("Total samples: {}", balanced_over.nrows());
90
91 // Strategy 2: Random Undersampling
92 let (balanced_under, targets_under) = create_balanced_dataset(
93 &data,
94 &targets,
95 BalancingStrategy::RandomUndersample,
96 Some(42),
97 )
98 .unwrap();
99
100 println!("\nStrategy: Random Undersampling");
101 print_class_distribution(&targets_under);
102 println!("Total samples: {}", balanced_under.nrows());
103
104 // Strategy 3: SMOTE with k=1 neighbors
105 let (balanced_smote, targets_smote) = create_balanced_dataset(
106 &data,
107 &targets,
108 BalancingStrategy::SMOTE { k_neighbors: 1 },
109 Some(42),
110 )
111 .unwrap();
112
113 println!("\nStrategy: SMOTE (k_neighbors=1)");
114 print_class_distribution(&targets_smote);
115 println!("Total samples: {}", balanced_smote.nrows());
116
117 // Demonstrate with real-world dataset
118 println!("\n=== Real-world Example: Iris Dataset ==========");
119
120 let iris = load_iris().unwrap();
121 if let Some(iris_targets) = &iris.target {
122 println!("Original Iris dataset:");
123 print_class_distribution(iris_targets);
124
125 // Apply oversampling to iris (it's already balanced, but for demonstration)
126 let (iris_balanced, iris_balanced_targets) =
127 random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
128
129 println!("\nIris after oversampling (should remain the same):");
130 print_class_distribution(&iris_balanced_targets);
131 println!("Total samples: {}", iris_balanced.nrows());
132
133 // Create artificial imbalance by removing some samples
134 let indices_to_keep: Vec<usize> = (0..150)
135 .filter(|&i| {
136 let class = iris_targets[i].round() as i64;
137 // Keep all of class 0, 30 of class 1, 10 of class 2
138 match class {
139 0 => true, // Keep all 50
140 1 => i < 80, // Keep first 30 (indices 50-79)
141 2 => i < 110, // Keep first 10 (indices 100-109)
142 _ => false,
143 }
144 })
145 .collect();
146
147 let imbalanced_data = iris.data.select(ndarray::Axis(0), &indices_to_keep);
148 let imbalanced_targets = iris_targets.select(ndarray::Axis(0), &indices_to_keep);
149
150 println!("\nArtificially imbalanced Iris:");
151 print_class_distribution(&imbalanced_targets);
152
153 // Balance it using SMOTE
154 let (rebalanced_data, rebalanced_targets) = create_balanced_dataset(
155 &imbalanced_data,
156 &imbalanced_targets,
157 BalancingStrategy::SMOTE { k_neighbors: 3 },
158 Some(42),
159 )
160 .unwrap();
161
162 println!("\nAfter SMOTE rebalancing:");
163 print_class_distribution(&rebalanced_targets);
164 println!("Total samples: {}", rebalanced_data.nrows());
165 }
166
167 println!("\n=== Performance Comparison ====================");
168
169 // Show the tradeoffs between different strategies
170 println!("Strategy Comparison Summary:");
171 println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
172 println!("│ Strategy │ Final Size │ Characteristics │");
173 println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
174 println!(
175 "│ Random Oversample │ {} samples │ Increases data size, duplicates │",
176 balanced_over.nrows()
177 );
178 println!(
179 "│ Random Undersample │ {} samples │ Reduces data size, loses info │",
180 balanced_under.nrows()
181 );
182 println!(
183 "│ SMOTE │ {} samples │ Increases size, synthetic data │",
184 balanced_smote.nrows()
185 );
186 println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
187
188 println!("\n=== Balancing Demo Complete ====================");
189}