random_undersample

Function random_undersample 

Source
pub fn random_undersample(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description

Performs random undersampling to balance class distribution

Randomly removes samples from the majority classes until each class has as many samples as the minority class. This reduces the overall dataset size, but the resulting class distribution is balanced.

§Arguments

  • data - Feature matrix of shape (n_samples, n_features), one row per sample
  • targets - Target values for each sample
  • random_seed - Optional random seed for reproducible sampling

§Returns

On success, a tuple (data, targets) containing the undersampled feature matrix and the corresponding target values, wrapped in a Result

§Examples

use ndarray::{Array1, Array2};
use scirs2_datasets::utils::random_undersample;

let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42)).unwrap();
// Now both classes have 2 samples each
Examples found in repository?
examples/balancing_demo.rs (line 47)
13fn main() {
14    println!("=== Data Balancing Utilities Demonstration ===\n");
15
16    // Create an artificially imbalanced dataset for demonstration
17    let data = Array2::from_shape_vec(
18        (10, 2),
19        vec![
20            // Class 0 (minority): 2 samples
21            1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
22            5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
23            // Class 2 (moderate): 2 samples
24            10.0, 10.0, 10.1, 9.9,
25        ],
26    )
27    .unwrap();
28
29    let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
30
31    println!("Original imbalanced dataset:");
32    print_class_distribution(&targets);
33    println!("Total samples: {}\n", data.nrows());
34
35    // Demonstrate random oversampling
36    println!("=== Random Oversampling =======================");
37    let (oversampleddata, oversampled_targets) =
38        random_oversample(&data, &targets, Some(42)).unwrap();
39
40    println!("After random oversampling:");
41    print_class_distribution(&oversampled_targets);
42    println!("Total samples: {}\n", oversampleddata.nrows());
43
44    // Demonstrate random undersampling
45    println!("=== Random Undersampling ======================");
46    let (undersampleddata, undersampled_targets) =
47        random_undersample(&data, &targets, Some(42)).unwrap();
48
49    println!("After random undersampling:");
50    print_class_distribution(&undersampled_targets);
51    println!("Total samples: {}\n", undersampleddata.nrows());
52
53    // Demonstrate SMOTE-like synthetic sample generation
54    println!("=== Synthetic Sample Generation (SMOTE-like) ==");
55
56    // Generate 4 synthetic samples for class 0 (minority class)
57    let (syntheticdata, synthetic_targets) =
58        generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
59
60    println!(
61        "Generated {} synthetic samples for class 0",
62        syntheticdata.nrows()
63    );
64    println!("Synthetic samples (first 3 features of each):");
65    for i in 0..syntheticdata.nrows() {
66        println!(
67            "  Sample {}: [{:.3}, {:.3}] -> class {}",
68            i,
69            syntheticdata[[i, 0]],
70            syntheticdata[[i, 1]],
71            synthetic_targets[i]
72        );
73    }
74    println!();
75
76    // Demonstrate unified balancing strategies
77    println!("=== Unified Balancing Strategies ==============");
78
79    // Strategy 1: Random Oversampling
80    let (balanced_over, targets_over) = create_balanced_dataset(
81        &data,
82        &targets,
83        BalancingStrategy::RandomOversample,
84        Some(42),
85    )
86    .unwrap();
87
88    println!("Strategy: Random Oversampling");
89    print_class_distribution(&targets_over);
90    println!("Total samples: {}", balanced_over.nrows());
91
92    // Strategy 2: Random Undersampling
93    let (balanced_under, targets_under) = create_balanced_dataset(
94        &data,
95        &targets,
96        BalancingStrategy::RandomUndersample,
97        Some(42),
98    )
99    .unwrap();
100
101    println!("\nStrategy: Random Undersampling");
102    print_class_distribution(&targets_under);
103    println!("Total samples: {}", balanced_under.nrows());
104
105    // Strategy 3: SMOTE with k=1 neighbors
106    let (balanced_smote, targets_smote) = create_balanced_dataset(
107        &data,
108        &targets,
109        BalancingStrategy::SMOTE { k_neighbors: 1 },
110        Some(42),
111    )
112    .unwrap();
113
114    println!("\nStrategy: SMOTE (k_neighbors=1)");
115    print_class_distribution(&targets_smote);
116    println!("Total samples: {}", balanced_smote.nrows());
117
118    // Demonstrate with real-world dataset
119    println!("\n=== Real-world Example: Iris Dataset ==========");
120
121    let iris = load_iris().unwrap();
122    if let Some(iris_targets) = &iris.target {
123        println!("Original Iris dataset:");
124        print_class_distribution(iris_targets);
125
126        // Apply oversampling to iris (it's already balanced, but for demonstration)
127        let (iris_balanced, iris_balanced_targets) =
128            random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
129
130        println!("\nIris after oversampling (should remain the same):");
131        print_class_distribution(&iris_balanced_targets);
132        println!("Total samples: {}", iris_balanced.nrows());
133
134        // Create artificial imbalance by removing some samples
135        let indices_to_keep: Vec<usize> = (0..150)
136            .filter(|&i| {
137                let class = iris_targets[i].round() as i64;
138                // Keep all of class 0, 30 of class 1, 10 of class 2
139                match class {
140                    0 => true,    // Keep all 50
141                    1 => i < 80,  // Keep first 30 (indices 50-79)
142                    2 => i < 110, // Keep first 10 (indices 100-109)
143                    _ => false,
144                }
145            })
146            .collect();
147
148        let imbalanceddata = iris.data.select(ndarray::Axis(0), &indices_to_keep);
149        let imbalanced_targets = iris_targets.select(ndarray::Axis(0), &indices_to_keep);
150
151        println!("\nArtificially imbalanced Iris:");
152        print_class_distribution(&imbalanced_targets);
153
154        // Balance it using SMOTE
155        let (rebalanceddata, rebalanced_targets) = create_balanced_dataset(
156            &imbalanceddata,
157            &imbalanced_targets,
158            BalancingStrategy::SMOTE { k_neighbors: 3 },
159            Some(42),
160        )
161        .unwrap();
162
163        println!("\nAfter SMOTE rebalancing:");
164        print_class_distribution(&rebalanced_targets);
165        println!("Total samples: {}", rebalanceddata.nrows());
166    }
167
168    println!("\n=== Performance Comparison ====================");
169
170    // Show the tradeoffs between different strategies
171    println!("Strategy Comparison Summary:");
172    println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
173    println!("│ Strategy            │ Final Size   │ Characteristics                 │");
174    println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
175    println!(
176        "│ Random Oversample   │ {} samples   │ Increases data size, duplicates │",
177        balanced_over.nrows()
178    );
179    println!(
180        "│ Random Undersample  │ {} samples    │ Reduces data size, loses info   │",
181        balanced_under.nrows()
182    );
183    println!(
184        "│ SMOTE               │ {} samples   │ Increases size, synthetic data  │",
185        balanced_smote.nrows()
186    );
187    println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
188
189    println!("\n=== Balancing Demo Complete ====================");
190}