Function random_oversample

Source
pub fn random_oversample(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)>
Expand description

Performs random oversampling to balance the class distribution.

Duplicates samples from minority classes to match the majority class size. This is useful for handling imbalanced datasets in classification problems.

§Arguments

  • data - Feature matrix (n_samples, n_features)
  • targets - Class label for each sample (one entry per row of data)
  • random_seed - Optional random seed for reproducible sampling

§Returns

On success, a tuple containing the resampled (data, targets) arrays; the tuple is wrapped in a Result.

§Examples

use ndarray::{Array1, Array2};
use scirs2_datasets::utils::random_oversample;

// Six samples with two features each, stored row by row.
let feature_values = vec![
    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
];
let data = Array2::from_shape_vec((6, 2), feature_values).unwrap();

// Imbalanced labels: class 0.0 has 2 samples, class 1.0 has 4.
let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

// Passing a seed keeps the sampling reproducible.
let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42)).unwrap();
// Now both classes have 4 samples each
Examples found in repository?
examples/balancing_demo.rs (line 37)
12fn main() {
13    println!("=== Data Balancing Utilities Demonstration ===\n");
14
15    // Create an artificially imbalanced dataset for demonstration
16    let data = Array2::from_shape_vec(
17        (10, 2),
18        vec![
19            // Class 0 (minority): 2 samples
20            1.0, 1.0, 1.2, 1.1, // Class 1 (majority): 6 samples
21            5.0, 5.0, 5.1, 5.2, 4.9, 4.8, 5.3, 5.1, 4.8, 5.3, 5.0, 4.9,
22            // Class 2 (moderate): 2 samples
23            10.0, 10.0, 10.1, 9.9,
24        ],
25    )
26    .unwrap();
27
28    let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]);
29
30    println!("Original imbalanced dataset:");
31    print_class_distribution(&targets);
32    println!("Total samples: {}\n", data.nrows());
33
34    // Demonstrate random oversampling
35    println!("=== Random Oversampling =======================");
36    let (oversampled_data, oversampled_targets) =
37        random_oversample(&data, &targets, Some(42)).unwrap();
38
39    println!("After random oversampling:");
40    print_class_distribution(&oversampled_targets);
41    println!("Total samples: {}\n", oversampled_data.nrows());
42
43    // Demonstrate random undersampling
44    println!("=== Random Undersampling ======================");
45    let (undersampled_data, undersampled_targets) =
46        random_undersample(&data, &targets, Some(42)).unwrap();
47
48    println!("After random undersampling:");
49    print_class_distribution(&undersampled_targets);
50    println!("Total samples: {}\n", undersampled_data.nrows());
51
52    // Demonstrate SMOTE-like synthetic sample generation
53    println!("=== Synthetic Sample Generation (SMOTE-like) ==");
54
55    // Generate 4 synthetic samples for class 0 (minority class)
56    let (synthetic_data, synthetic_targets) =
57        generate_synthetic_samples(&data, &targets, 0.0, 4, 1, Some(42)).unwrap();
58
59    println!(
60        "Generated {} synthetic samples for class 0",
61        synthetic_data.nrows()
62    );
63    println!("Synthetic samples (first 3 features of each):");
64    for i in 0..synthetic_data.nrows() {
65        println!(
66            "  Sample {}: [{:.3}, {:.3}] -> class {}",
67            i,
68            synthetic_data[[i, 0]],
69            synthetic_data[[i, 1]],
70            synthetic_targets[i]
71        );
72    }
73    println!();
74
75    // Demonstrate unified balancing strategies
76    println!("=== Unified Balancing Strategies ==============");
77
78    // Strategy 1: Random Oversampling
79    let (balanced_over, targets_over) = create_balanced_dataset(
80        &data,
81        &targets,
82        BalancingStrategy::RandomOversample,
83        Some(42),
84    )
85    .unwrap();
86
87    println!("Strategy: Random Oversampling");
88    print_class_distribution(&targets_over);
89    println!("Total samples: {}", balanced_over.nrows());
90
91    // Strategy 2: Random Undersampling
92    let (balanced_under, targets_under) = create_balanced_dataset(
93        &data,
94        &targets,
95        BalancingStrategy::RandomUndersample,
96        Some(42),
97    )
98    .unwrap();
99
100    println!("\nStrategy: Random Undersampling");
101    print_class_distribution(&targets_under);
102    println!("Total samples: {}", balanced_under.nrows());
103
104    // Strategy 3: SMOTE with k=1 neighbors
105    let (balanced_smote, targets_smote) = create_balanced_dataset(
106        &data,
107        &targets,
108        BalancingStrategy::SMOTE { k_neighbors: 1 },
109        Some(42),
110    )
111    .unwrap();
112
113    println!("\nStrategy: SMOTE (k_neighbors=1)");
114    print_class_distribution(&targets_smote);
115    println!("Total samples: {}", balanced_smote.nrows());
116
117    // Demonstrate with real-world dataset
118    println!("\n=== Real-world Example: Iris Dataset ==========");
119
120    let iris = load_iris().unwrap();
121    if let Some(iris_targets) = &iris.target {
122        println!("Original Iris dataset:");
123        print_class_distribution(iris_targets);
124
125        // Apply oversampling to iris (it's already balanced, but for demonstration)
126        let (iris_balanced, iris_balanced_targets) =
127            random_oversample(&iris.data, iris_targets, Some(42)).unwrap();
128
129        println!("\nIris after oversampling (should remain the same):");
130        print_class_distribution(&iris_balanced_targets);
131        println!("Total samples: {}", iris_balanced.nrows());
132
133        // Create artificial imbalance by removing some samples
134        let indices_to_keep: Vec<usize> = (0..150)
135            .filter(|&i| {
136                let class = iris_targets[i].round() as i64;
137                // Keep all of class 0, 30 of class 1, 10 of class 2
138                match class {
139                    0 => true,    // Keep all 50
140                    1 => i < 80,  // Keep first 30 (indices 50-79)
141                    2 => i < 110, // Keep first 10 (indices 100-109)
142                    _ => false,
143                }
144            })
145            .collect();
146
147        let imbalanced_data = iris.data.select(ndarray::Axis(0), &indices_to_keep);
148        let imbalanced_targets = iris_targets.select(ndarray::Axis(0), &indices_to_keep);
149
150        println!("\nArtificially imbalanced Iris:");
151        print_class_distribution(&imbalanced_targets);
152
153        // Balance it using SMOTE
154        let (rebalanced_data, rebalanced_targets) = create_balanced_dataset(
155            &imbalanced_data,
156            &imbalanced_targets,
157            BalancingStrategy::SMOTE { k_neighbors: 3 },
158            Some(42),
159        )
160        .unwrap();
161
162        println!("\nAfter SMOTE rebalancing:");
163        print_class_distribution(&rebalanced_targets);
164        println!("Total samples: {}", rebalanced_data.nrows());
165    }
166
167    println!("\n=== Performance Comparison ====================");
168
169    // Show the tradeoffs between different strategies
170    println!("Strategy Comparison Summary:");
171    println!("┌─────────────────────┬──────────────┬─────────────────────────────────┐");
172    println!("│ Strategy            │ Final Size   │ Characteristics                 │");
173    println!("├─────────────────────┼──────────────┼─────────────────────────────────┤");
174    println!(
175        "│ Random Oversample   │ {} samples   │ Increases data size, duplicates │",
176        balanced_over.nrows()
177    );
178    println!(
179        "│ Random Undersample  │ {} samples    │ Reduces data size, loses info   │",
180        balanced_under.nrows()
181    );
182    println!(
183        "│ SMOTE               │ {} samples   │ Increases size, synthetic data  │",
184        balanced_smote.nrows()
185    );
186    println!("└─────────────────────┴──────────────┴─────────────────────────────────┘");
187
188    println!("\n=== Balancing Demo Complete ====================");
189}