Function create_balanced_dataset

Source
pub fn create_balanced_dataset(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    strategy: BalancingStrategy,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description

Creates a balanced dataset using the specified balancing strategy.

Automatically balances the dataset by applying oversampling, undersampling, or synthetic sample generation based on the specified strategy.

§Arguments

  • data - Feature matrix (n_samples, n_features)
  • targets - Target value (class label) for each sample, one entry per row of data
  • strategy - Balancing strategy to use
  • random_seed - Optional random seed for reproducible balancing

§Returns

A tuple containing the balanced (data, targets) arrays, or an error if balancing fails.

§Examples

use ndarray::{Array1, Array2};
use scirs2_datasets::utils::{create_balanced_dataset, BalancingStrategy};

// Six samples with two features each: rows 0-1 belong to class 0 and rows 2-5 to
// class 1, so class 0 is the minority class.
let data = Array2::from_shape_vec(
    (6, 2),
    vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
)
.unwrap();
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

let (balanced_data, balanced_targets) = create_balanced_dataset(
    &data,
    &targets,
    BalancingStrategy::RandomOversample,
    Some(42), // fixed seed so the resampling is reproducible
)
.unwrap();

// The balanced feature matrix and target vector describe the same samples.
assert_eq!(balanced_data.nrows(), balanced_targets.len());
Examples found in repository?
examples/balancing_demo.rs (lines 79-84)
12fn main() {
13    println!("=== Data Balancing Utilities Demonstration ===\n");
14
15    // Create an artificially imbalanced dataset for demonstration
16    let data = Array2::from_shape_vec(
17        (10, 2),
18        vec![
19            // Class 0 (minority): 2 samples
20            1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
21            5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
22            // Class 2 (moderate): 2 samples
23            10.0, 10.0, 10.1, 9.9,
24        ],
25    )
26    .unwrap();
27
28    let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
29
30    println!("Original imbalanced dataset:");
31    print_class_distribution(&targets);
32    println!("Total samples: {}\n", data.nrows());
33
34    // Demonstrate random oversampling
35    println!("=== Random Oversampling =======================");
36    let (oversampled_data, oversampled_targets) =
37        random_oversample(&data, &targets, Some(42)).unwrap();
38
39    println!("After random oversampling:");
40    print_class_distribution(&oversampled_targets);
41    println!("Total samples: {}\n", oversampled_data.nrows());
42
43    // Demonstrate random undersampling
44    println!("=== Random Undersampling ======================");
45    let (undersampled_data, undersampled_targets) =
46        random_undersample(&data, &targets, Some(42)).unwrap();
47
48    println!("After random undersampling:");
49    print_class_distribution(&undersampled_targets);
50    println!("Total samples: {}\n", undersampled_data.nrows());
51
52    // Demonstrate SMOTE-like synthetic sample generation
53    println!("=== Synthetic Sample Generation (SMOTE-like) ==");
54
55    // Generate 4 synthetic samples for class 0 (minority class)
56    let (synthetic_data, synthetic_targets) =
57        generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
58
59    println!(
60        "Generated {} synthetic samples for class 0",
61        synthetic_data.nrows()
62    );
63    println!("Synthetic samples (first 3 features of each):");
64    for i in 0..synthetic_data.nrows() {
65        println!(
66            "  Sample {}: [{:.3}, {:.3}] -> class {}",
67            i,
68            synthetic_data[[i, 0]],
69            synthetic_data[[i, 1]],
70            synthetic_targets[i]
71        );
72    }
73    println!();
74
75    // Demonstrate unified balancing strategies
76    println!("=== Unified Balancing Strategies ==============");
77
78    // Strategy 1: Random Oversampling
79    let (balanced_over, targets_over) = create_balanced_dataset(
80        &data,
81        &targets,
82        BalancingStrategy::RandomOversample,
83        Some(42),
84    )
85    .unwrap();
86
87    println!("Strategy: Random Oversampling");
88    print_class_distribution(&targets_over);
89    println!("Total samples: {}", balanced_over.nrows());
90
91    // Strategy 2: Random Undersampling
92    let (balanced_under, targets_under) = create_balanced_dataset(
93        &data,
94        &targets,
95        BalancingStrategy::RandomUndersample,
96        Some(42),
97    )
98    .unwrap();
99
100    println!("\nStrategy: Random Undersampling");
101    print_class_distribution(&targets_under);
102    println!("Total samples: {}", balanced_under.nrows());
103
104    // Strategy 3: SMOTE with k=1 neighbors
105    let (balanced_smote, targets_smote) = create_balanced_dataset(
106        &data,
107        &targets,
108        BalancingStrategy::SMOTE { k_neighbors: 1 },
109        Some(42),
110    )
111    .unwrap();
112
113    println!("\nStrategy: SMOTE (k_neighbors=1)");
114    print_class_distribution(&targets_smote);
115    println!("Total samples: {}", balanced_smote.nrows());
116
117    // Demonstrate with real-world dataset
118    println!("\n=== Real-world Example: Iris Dataset ==========");
119
120    let iris = load_iris().unwrap();
121    if let Some(iris_targets) = &iris.target {
122        println!("Original Iris dataset:");
123        print_class_distribution(iris_targets);
124
125        // Apply oversampling to iris (it's already balanced, but for demonstration)
126        let (iris_balanced, iris_balanced_targets) =
127            random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
128
129        println!("\nIris after oversampling (should remain the same):");
130        print_class_distribution(&iris_balanced_targets);
131        println!("Total samples: {}", iris_balanced.nrows());
132
133        // Create artificial imbalance by removing some samples
134        let indices_to_keep: Vec<usize> = (0..150)
135            .filter(|&i| {
136                let class = iris_targets[i].round() as i64;
137                // Keep all of class 0, 30 of class 1, 10 of class 2
138                match class {
139                    0 => true,    // Keep all 50
140                    1 => i < 80,  // Keep first 30 (indices 50-79)
141                    2 => i < 110, // Keep first 10 (indices 100-109)
142                    _ => false,
143                }
144            })
145            .collect();
146
147        let imbalanced_data = iris.data.select(ndarray::Axis(0), &indices_to_keep);
148        let imbalanced_targets = iris_targets.select(ndarray::Axis(0), &indices_to_keep);
149
150        println!("\nArtificially imbalanced Iris:");
151        print_class_distribution(&imbalanced_targets);
152
153        // Balance it using SMOTE
154        let (rebalanced_data, rebalanced_targets) = create_balanced_dataset(
155            &imbalanced_data,
156            &imbalanced_targets,
157            BalancingStrategy::SMOTE { k_neighbors: 3 },
158            Some(42),
159        )
160        .unwrap();
161
162        println!("\nAfter SMOTE rebalancing:");
163        print_class_distribution(&rebalanced_targets);
164        println!("Total samples: {}", rebalanced_data.nrows());
165    }
166
167    println!("\n=== Performance Comparison ====================");
168
169    // Show the tradeoffs between different strategies
170    println!("Strategy Comparison Summary:");
171    println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
172    println!("│ Strategy            │ Final Size   │ Characteristics                 │");
173    println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
174    println!(
175        "│ Random Oversample   │ {} samples   │ Increases data size, duplicates │",
176        balanced_over.nrows()
177    );
178    println!(
179        "│ Random Undersample  │ {} samples    │ Reduces data size, loses info   │",
180        balanced_under.nrows()
181    );
182    println!(
183        "│ SMOTE               │ {} samples   │ Increases size, synthetic data  │",
184        balanced_smote.nrows()
185    );
186    println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
187
188    println!("\n=== Balancing Demo Complete ====================");
189}