random_undersample

Function random_undersample 

Source
pub fn random_undersample(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description

Performs random undersampling to balance the class distribution.

Randomly removes samples from the majority classes until each class has as many samples as the minority class. This reduces the overall dataset size, discarding some information, but yields a balanced class distribution.

§Arguments

  • data - Feature matrix (n_samples, n_features)
  • targets - Target values for each sample
  • random_seed - Optional random seed for reproducible sampling

§Returns

A `Result` wrapping a tuple of the undersampled `(data, targets)` arrays

§Examples

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_datasets::utils::random_undersample;

// Feature matrix with 6 samples (rows) and 2 features (columns) per sample.
let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
// One label per row: class 0.0 appears twice, class 1.0 appears four times.
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
// Undersample with a fixed seed (42) so the selection is reproducible.
let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42)).unwrap();
// Now both classes have 2 samples each
Examples found in repository?
examples/balancing_demo.rs (line 47)
13fn main() {
14    println!("=== Data Balancing Utilities Demonstration ===\n");
15
16    // Create an artificially imbalanced dataset for demonstration
17    let data = Array2::from_shape_vec(
18        (10, 2),
19        vec![
20            // Class 0 (minority): 2 samples
21            1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
22            5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
23            // Class 2 (moderate): 2 samples
24            10.0, 10.0, 10.1, 9.9,
25        ],
26    )
27    .unwrap();
28
29    let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
30
31    println!("Original imbalanced dataset:");
32    print_class_distribution(&targets);
33    println!("Total samples: {}\n", data.nrows());
34
35    // Demonstrate random oversampling
36    println!("=== Random Oversampling =======================");
37    let (oversampleddata, oversampled_targets) =
38        random_oversample(&data, &targets, Some(42)).unwrap();
39
40    println!("After random oversampling:");
41    print_class_distribution(&oversampled_targets);
42    println!("Total samples: {}\n", oversampleddata.nrows());
43
44    // Demonstrate random undersampling
45    println!("=== Random Undersampling ======================");
46    let (undersampleddata, undersampled_targets) =
47        random_undersample(&data, &targets, Some(42)).unwrap();
48
49    println!("After random undersampling:");
50    print_class_distribution(&undersampled_targets);
51    println!("Total samples: {}\n", undersampleddata.nrows());
52
53    // Demonstrate SMOTE-like synthetic sample generation
54    println!("=== Synthetic Sample Generation (SMOTE-like) ==");
55
56    // Generate 4 synthetic samples for class 0 (minority class)
57    let (syntheticdata, synthetic_targets) =
58        generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
59
60    println!(
61        "Generated {} synthetic samples for class 0",
62        syntheticdata.nrows()
63    );
64    println!("Synthetic samples (first 3 features of each):");
65    for i in 0..syntheticdata.nrows() {
66        println!(
67            "  Sample {}: [{:.3}, {:.3}] -> class {}",
68            i,
69            syntheticdata[[i, 0]],
70            syntheticdata[[i, 1]],
71            synthetic_targets[i]
72        );
73    }
74    println!();
75
76    // Demonstrate unified balancing strategies
77    println!("=== Unified Balancing Strategies ==============");
78
79    // Strategy 1: Random Oversampling
80    let (balanced_over, targets_over) = create_balanced_dataset(
81        &data,
82        &targets,
83        BalancingStrategy::RandomOversample,
84        Some(42),
85    )
86    .unwrap();
87
88    println!("Strategy: Random Oversampling");
89    print_class_distribution(&targets_over);
90    println!("Total samples: {}", balanced_over.nrows());
91
92    // Strategy 2: Random Undersampling
93    let (balanced_under, targets_under) = create_balanced_dataset(
94        &data,
95        &targets,
96        BalancingStrategy::RandomUndersample,
97        Some(42),
98    )
99    .unwrap();
100
101    println!("\nStrategy: Random Undersampling");
102    print_class_distribution(&targets_under);
103    println!("Total samples: {}", balanced_under.nrows());
104
105    // Strategy 3: SMOTE with k=1 neighbors
106    let (balanced_smote, targets_smote) = create_balanced_dataset(
107        &data,
108        &targets,
109        BalancingStrategy::SMOTE { k_neighbors: 1 },
110        Some(42),
111    )
112    .unwrap();
113
114    println!("\nStrategy: SMOTE (k_neighbors=1)");
115    print_class_distribution(&targets_smote);
116    println!("Total samples: {}", balanced_smote.nrows());
117
118    // Demonstrate with real-world dataset
119    println!("\n=== Real-world Example: Iris Dataset ==========");
120
121    let iris = load_iris().unwrap();
122    if let Some(iris_targets) = &iris.target {
123        println!("Original Iris dataset:");
124        print_class_distribution(iris_targets);
125
126        // Apply oversampling to iris (it's already balanced, but for demonstration)
127        let (iris_balanced, iris_balanced_targets) =
128            random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
129
130        println!("\nIris after oversampling (should remain the same):");
131        print_class_distribution(&iris_balanced_targets);
132        println!("Total samples: {}", iris_balanced.nrows());
133
134        // Create artificial imbalance by removing some samples
135        let indices_to_keep: Vec<usize> = (0..150)
136            .filter(|&i| {
137                let class = iris_targets[i].round() as i64;
138                // Keep all of class 0, 30 of class 1, 10 of class 2
139                match class {
140                    0 => true,    // Keep all 50
141                    1 => i < 80,  // Keep first 30 (indices 50-79)
142                    2 => i < 110, // Keep first 10 (indices 100-109)
143                    _ => false,
144                }
145            })
146            .collect();
147
148        let imbalanceddata = iris
149            .data
150            .select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
151        let imbalanced_targets =
152            iris_targets.select(scirs2_core::ndarray::Axis(0), &indices_to_keep);
153
154        println!("\nArtificially imbalanced Iris:");
155        print_class_distribution(&imbalanced_targets);
156
157        // Balance it using SMOTE
158        let (rebalanceddata, rebalanced_targets) = create_balanced_dataset(
159            &imbalanceddata,
160            &imbalanced_targets,
161            BalancingStrategy::SMOTE { k_neighbors: 3 },
162            Some(42),
163        )
164        .unwrap();
165
166        println!("\nAfter SMOTE rebalancing:");
167        print_class_distribution(&rebalanced_targets);
168        println!("Total samples: {}", rebalanceddata.nrows());
169    }
170
171    println!("\n=== Performance Comparison ====================");
172
173    // Show the tradeoffs between different strategies
174    println!("Strategy Comparison Summary:");
175    println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
176    println!("│ Strategy            │ Final Size   │ Characteristics                 │");
177    println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
178    println!(
179        "│ Random Oversample   │ {} samples   │ Increases data size, duplicates │",
180        balanced_over.nrows()
181    );
182    println!(
183        "│ Random Undersample  │ {} samples    │ Reduces data size, loses info   │",
184        balanced_under.nrows()
185    );
186    println!(
187        "│ SMOTE               │ {} samples   │ Increases size, synthetic data  │",
188        balanced_smote.nrows()
189    );
190    println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
191
192    println!("\n=== Balancing Demo Complete ====================");
193}