create_balanced_dataset

Function create_balanced_dataset 

Source
pub fn create_balanced_dataset(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    strategy: BalancingStrategy,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description

Creates a balanced dataset using the specified balancing strategy.

Automatically balances the dataset by applying oversampling, undersampling, or synthetic sample generation based on the specified strategy.

§Arguments

  • data - Feature matrix (n_samples, n_features)
  • targets - Target values for each sample
  • strategy - Balancing strategy to use
  • random_seed - Optional random seed for reproducible balancing

§Returns

On success, a tuple (data, targets) containing the balanced feature matrix and the matching target values, wrapped in a Result.

§Examples

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_datasets::utils::{create_balanced_dataset, BalancingStrategy};

// Targets: class 0.0 appears twice and class 1.0 four times, so the set is imbalanced.
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

// Feature matrix: six samples with two features each, stored row by row.
let data = Array2::from_shape_vec(
    (6, 2),
    vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
    ],
)
.unwrap();

// Balance with random oversampling; the fixed seed makes the result reproducible.
let strategy = BalancingStrategy::RandomOversample;
let (balanced_data, balanced_targets) =
    create_balanced_dataset(&data, &targets, strategy, Some(42)).unwrap();
Examples found in repository?
examples/balancing_demo.rs (lines 80-85)
13fn main() {
14    println!("=== Data Balancing Utilities Demonstration ===\n");
15
16    // Create an artificially imbalanced dataset for demonstration
17    let data = Array2::from_shape_vec(
18        (10, 2),
19        vec![
20            // Class 0 (minority): 2 samples
21            1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
22            5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
23            // Class 2 (moderate): 2 samples
24            10.0, 10.0, 10.1, 9.9,
25        ],
26    )
27    .unwrap();
28
29    let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
30
31    println!("Original imbalanced dataset:");
32    print_class_distribution(&targets);
33    println!("Total samples: {}\n", data.nrows());
34
35    // Demonstrate random oversampling
36    println!("=== Random Oversampling =======================");
37    let (oversampleddata, oversampled_targets) =
38        random_oversample(&data, &targets, Some(42)).unwrap();
39
40    println!("After random oversampling:");
41    print_class_distribution(&oversampled_targets);
42    println!("Total samples: {}\n", oversampleddata.nrows());
43
44    // Demonstrate random undersampling
45    println!("=== Random Undersampling ======================");
46    let (undersampleddata, undersampled_targets) =
47        random_undersample(&data, &targets, Some(42)).unwrap();
48
49    println!("After random undersampling:");
50    print_class_distribution(&undersampled_targets);
51    println!("Total samples: {}\n", undersampleddata.nrows());
52
53    // Demonstrate SMOTE-like synthetic sample generation
54    println!("=== Synthetic Sample Generation (SMOTE-like) ==");
55
56    // Generate 4 synthetic samples for class 0 (minority class)
57    let (syntheticdata, synthetic_targets) =
58        generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
59
60    println!(
61        "Generated {} synthetic samples for class 0",
62        syntheticdata.nrows()
63    );
64    println!("Synthetic samples (first 3 features of each):");
65    for i in 0..syntheticdata.nrows() {
66        println!(
67            "  Sample {}: [{:.3}, {:.3}] -> class {}",
68            i,
69            syntheticdata[[i, 0]],
70            syntheticdata[[i, 1]],
71            synthetic_targets[i]
72        );
73    }
74    println!();
75
76    // Demonstrate unified balancing strategies
77    println!("=== Unified Balancing Strategies ==============");
78
79    // Strategy 1: Random Oversampling
80    let (balanced_over, targets_over) = create_balanced_dataset(
81        &data,
82        &targets,
83        BalancingStrategy::RandomOversample,
84        Some(42),
85    )
86    .unwrap();
87
88    println!("Strategy: Random Oversampling");
89    print_class_distribution(&targets_over);
90    println!("Total samples: {}", balanced_over.nrows());
91
92    // Strategy 2: Random Undersampling
93    let (balanced_under, targets_under) = create_balanced_dataset(
94        &data,
95        &targets,
96        BalancingStrategy::RandomUndersample,
97        Some(42),
98    )
99    .unwrap();
100
101    println!("\nStrategy: Random Undersampling");
102    print_class_distribution(&targets_under);
103    println!("Total samples: {}", balanced_under.nrows());
104
105    // Strategy 3: SMOTE with k=1 neighbors
106    let (balanced_smote, targets_smote) = create_balanced_dataset(
107        &data,
108        &targets,
109        BalancingStrategy::SMOTE { k_neighbors: 1 },
110        Some(42),
111    )
112    .unwrap();
113
114    println!("\nStrategy: SMOTE (k_neighbors=1)");
115    print_class_distribution(&targets_smote);
116    println!("Total samples: {}", balanced_smote.nrows());
117
118    // Demonstrate with real-world dataset
119    println!("\n=== Real-world Example: Iris Dataset ==========");
120
121    let iris = load_iris().unwrap();
122    if let Some(iris_targets) = &iris.target {
123        println!("Original Iris dataset:");
124        print_class_distribution(iris_targets);
125
126        // Apply oversampling to iris (it's already balanced, but for demonstration)
127        let (iris_balanced, iris_balanced_targets) =
128            random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
129
130        println!("\nIris after oversampling (should remain the same):");
131        print_class_distribution(&iris_balanced_targets);
132        println!("Total samples: {}", iris_balanced.nrows());
133
134        // Create artificial imbalance by removing some samples
135        let indices_to_keep: Vec<usize> = (0..150)
136            .filter(|&i| {
137                let class = iris_targets[i].round() as i64;
138                // Keep all of class 0, 30 of class 1, 10 of class 2
139                match class {
140                    0 => true,    // Keep all 50
141                    1 => i < 80,  // Keep first 30 (indices 50-79)
142                    2 => i < 110, // Keep first 10 (indices 100-109)
143                    _ => false,
144                }
145            })
146            .collect();
147
148        let imbalanceddata = iris
149            .data
150            .select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
151        let imbalanced_targets =
152            iris_targets.select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
153
154        println!("\nArtificially imbalanced Iris:");
155        print_class_distribution(&imbalanced_targets);
156
157        // Balance it using SMOTE
158        let (rebalanceddata, rebalanced_targets) = create_balanced_dataset(
159            &imbalanceddata,
160            &imbalanced_targets,
161            BalancingStrategy::SMOTE { k_neighbors: 3 },
162            Some(42),
163        )
164        .unwrap();
165
166        println!("\nAfter SMOTE rebalancing:");
167        print_class_distribution(&rebalanced_targets);
168        println!("Total samples: {}", rebalanceddata.nrows());
169    }
170
171    println!("\n=== Performance Comparison ====================");
172
173    // Show the tradeoffs between different strategies
174    println!("Strategy Comparison Summary:");
175    println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
176    println!("│ Strategy            │ Final Size   │ Characteristics                 │");
177    println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
178    println!(
179        "│ Random Oversample   │ {} samples   │ Increases data size, duplicates │",
180        balanced_over.nrows()
181    );
182    println!(
183        "│ Random Undersample  │ {} samples    │ Reduces data size, loses info   │",
184        balanced_under.nrows()
185    );
186    println!(
187        "│ SMOTE               │ {} samples   │ Increases size, synthetic data  │",
188        balanced_smote.nrows()
189    );
190    println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
191
192    println!("\n=== Balancing Demo Complete ====================");
193}