scirs2_datasets/utils/
balancing.rs

1//! Data balancing utilities for handling imbalanced datasets
2//!
3//! This module provides various strategies for balancing datasets to handle
4//! class imbalance problems in machine learning. It includes random oversampling,
5//! random undersampling, and SMOTE-like synthetic sample generation.
6
7use crate::error::{DatasetsError, Result};
8use ndarray::{Array1, Array2};
9use rand::prelude::*;
10use rand::rng;
11use rand::rngs::StdRng;
12use std::collections::HashMap;
13
/// Balancing strategies for handling imbalanced datasets
///
/// Passed to [`create_balanced_dataset`] to choose how class imbalance is corrected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BalancingStrategy {
    /// Random oversampling - duplicates minority class samples until every class
    /// matches the majority class size
    RandomOversample,
    /// Random undersampling - removes majority class samples until every class
    /// matches the minority class size
    RandomUndersample,
    /// SMOTE (Synthetic Minority Oversampling Technique) with specified k_neighbors
    SMOTE {
        /// Number of nearest neighbors to consider for synthetic sample generation.
        ///
        /// Must be at least 1 and smaller than the number of samples in every class
        /// that needs oversampling; otherwise balancing returns an error.
        k_neighbors: usize,
    },
}
27
28/// Performs random oversampling to balance class distribution
29///
30/// Duplicates samples from minority classes to match the majority class size.
31/// This is useful for handling imbalanced datasets in classification problems.
32///
33/// # Arguments
34///
35/// * `data` - Feature matrix (n_samples, n_features)
36/// * `targets` - Target values for each sample
37/// * `random_seed` - Optional random seed for reproducible sampling
38///
39/// # Returns
40///
41/// A tuple containing the resampled (data, targets) arrays
42///
43/// # Examples
44///
45/// ```rust
46/// use ndarray::{Array1, Array2};
47/// use scirs2_datasets::utils::random_oversample;
48///
49/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
50/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
51/// let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42)).unwrap();
52/// // Now both classes have 4 samples each
53/// ```
54pub fn random_oversample(
55    data: &Array2<f64>,
56    targets: &Array1<f64>,
57    random_seed: Option<u64>,
58) -> Result<(Array2<f64>, Array1<f64>)> {
59    if data.nrows() != targets.len() {
60        return Err(DatasetsError::InvalidFormat(
61            "Data rows and targets length must match".to_string(),
62        ));
63    }
64
65    if data.is_empty() || targets.is_empty() {
66        return Err(DatasetsError::InvalidFormat(
67            "Data and targets cannot be empty".to_string(),
68        ));
69    }
70
71    // Group indices by class
72    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
73    for (i, &target) in targets.iter().enumerate() {
74        let class = target.round() as i64;
75        class_indices.entry(class).or_default().push(i);
76    }
77
78    // Find the majority class size
79    let max_class_size = class_indices.values().map(|v| v.len()).max().unwrap();
80
81    let mut rng = match random_seed {
82        Some(seed) => StdRng::seed_from_u64(seed),
83        None => {
84            let mut r = rng();
85            StdRng::seed_from_u64(r.next_u64())
86        }
87    };
88
89    // Collect all resampled indices
90    let mut resampled_indices = Vec::new();
91
92    for (_, indices) in class_indices {
93        let class_size = indices.len();
94
95        // Add all original samples
96        resampled_indices.extend(&indices);
97
98        // Oversample if this class is smaller than the majority class
99        if class_size < max_class_size {
100            let samples_needed = max_class_size - class_size;
101            for _ in 0..samples_needed {
102                let random_idx = rng.random_range(0..class_size);
103                resampled_indices.push(indices[random_idx]);
104            }
105        }
106    }
107
108    // Create resampled data and targets
109    let resampled_data = data.select(ndarray::Axis(0), &resampled_indices);
110    let resampled_targets = targets.select(ndarray::Axis(0), &resampled_indices);
111
112    Ok((resampled_data, resampled_targets))
113}
114
115/// Performs random undersampling to balance class distribution
116///
117/// Randomly removes samples from majority classes to match the minority class size.
118/// This reduces the overall dataset size but maintains balance.
119///
120/// # Arguments
121///
122/// * `data` - Feature matrix (n_samples, n_features)
123/// * `targets` - Target values for each sample
124/// * `random_seed` - Optional random seed for reproducible sampling
125///
126/// # Returns
127///
128/// A tuple containing the undersampled (data, targets) arrays
129///
130/// # Examples
131///
132/// ```rust
133/// use ndarray::{Array1, Array2};
134/// use scirs2_datasets::utils::random_undersample;
135///
136/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
137/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
138/// let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42)).unwrap();
139/// // Now both classes have 2 samples each
140/// ```
141pub fn random_undersample(
142    data: &Array2<f64>,
143    targets: &Array1<f64>,
144    random_seed: Option<u64>,
145) -> Result<(Array2<f64>, Array1<f64>)> {
146    if data.nrows() != targets.len() {
147        return Err(DatasetsError::InvalidFormat(
148            "Data rows and targets length must match".to_string(),
149        ));
150    }
151
152    if data.is_empty() || targets.is_empty() {
153        return Err(DatasetsError::InvalidFormat(
154            "Data and targets cannot be empty".to_string(),
155        ));
156    }
157
158    // Group indices by class
159    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
160    for (i, &target) in targets.iter().enumerate() {
161        let class = target.round() as i64;
162        class_indices.entry(class).or_default().push(i);
163    }
164
165    // Find the minority class size
166    let min_class_size = class_indices.values().map(|v| v.len()).min().unwrap();
167
168    let mut rng = match random_seed {
169        Some(seed) => StdRng::seed_from_u64(seed),
170        None => {
171            let mut r = rng();
172            StdRng::seed_from_u64(r.next_u64())
173        }
174    };
175
176    // Collect undersampled indices
177    let mut undersampled_indices = Vec::new();
178
179    for (_, mut indices) in class_indices {
180        if indices.len() > min_class_size {
181            // Randomly sample down to minority class size
182            indices.shuffle(&mut rng);
183            undersampled_indices.extend(&indices[0..min_class_size]);
184        } else {
185            // Use all samples if already at or below minority class size
186            undersampled_indices.extend(&indices);
187        }
188    }
189
190    // Create undersampled data and targets
191    let undersampled_data = data.select(ndarray::Axis(0), &undersampled_indices);
192    let undersampled_targets = targets.select(ndarray::Axis(0), &undersampled_indices);
193
194    Ok((undersampled_data, undersampled_targets))
195}
196
197/// Generates synthetic samples using SMOTE-like interpolation
198///
199/// Creates synthetic samples by interpolating between existing samples within each class.
200/// This is useful for oversampling minority classes without simple duplication.
201///
202/// # Arguments
203///
204/// * `data` - Feature matrix (n_samples, n_features)
205/// * `targets` - Target values for each sample
206/// * `target_class` - The class to generate synthetic samples for
207/// * `n_synthetic` - Number of synthetic samples to generate
208/// * `k_neighbors` - Number of nearest neighbors to consider for interpolation
209/// * `random_seed` - Optional random seed for reproducible generation
210///
211/// # Returns
212///
213/// A tuple containing the synthetic (data, targets) arrays
214///
215/// # Examples
216///
217/// ```rust
218/// use ndarray::{Array1, Array2};
219/// use scirs2_datasets::utils::generate_synthetic_samples;
220///
221/// let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
222/// let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
223/// let (synthetic_data, synthetic_targets) = generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();
224/// assert_eq!(synthetic_data.nrows(), 2);
225/// assert_eq!(synthetic_targets.len(), 2);
226/// ```
227pub fn generate_synthetic_samples(
228    data: &Array2<f64>,
229    targets: &Array1<f64>,
230    target_class: f64,
231    n_synthetic: usize,
232    k_neighbors: usize,
233    random_seed: Option<u64>,
234) -> Result<(Array2<f64>, Array1<f64>)> {
235    if data.nrows() != targets.len() {
236        return Err(DatasetsError::InvalidFormat(
237            "Data rows and targets length must match".to_string(),
238        ));
239    }
240
241    if n_synthetic == 0 {
242        return Err(DatasetsError::InvalidFormat(
243            "Number of synthetic samples must be > 0".to_string(),
244        ));
245    }
246
247    if k_neighbors == 0 {
248        return Err(DatasetsError::InvalidFormat(
249            "Number of neighbors must be > 0".to_string(),
250        ));
251    }
252
253    // Find samples belonging to the target class
254    let class_indices: Vec<usize> = targets
255        .iter()
256        .enumerate()
257        .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
258        .map(|(i, _)| i)
259        .collect();
260
261    if class_indices.len() < 2 {
262        return Err(DatasetsError::InvalidFormat(
263            "Need at least 2 samples of the target class for synthetic generation".to_string(),
264        ));
265    }
266
267    if k_neighbors >= class_indices.len() {
268        return Err(DatasetsError::InvalidFormat(
269            "k_neighbors must be less than the number of samples in the target class".to_string(),
270        ));
271    }
272
273    let mut rng = match random_seed {
274        Some(seed) => StdRng::seed_from_u64(seed),
275        None => {
276            let mut r = rng();
277            StdRng::seed_from_u64(r.next_u64())
278        }
279    };
280
281    let n_features = data.ncols();
282    let mut synthetic_data = Array2::zeros((n_synthetic, n_features));
283    let synthetic_targets = Array1::from_elem(n_synthetic, target_class);
284
285    for i in 0..n_synthetic {
286        // Randomly select a sample from the target class
287        let base_idx = class_indices[rng.random_range(0..class_indices.len())];
288        let base_sample = data.row(base_idx);
289
290        // Find k nearest neighbors within the same class
291        let mut distances: Vec<(usize, f64)> = class_indices
292            .iter()
293            .filter(|&&idx| idx != base_idx)
294            .map(|&idx| {
295                let neighbor = data.row(idx);
296                let distance: f64 = base_sample
297                    .iter()
298                    .zip(neighbor.iter())
299                    .map(|(&a, &b)| (a - b).powi(2))
300                    .sum::<f64>()
301                    .sqrt();
302                (idx, distance)
303            })
304            .collect();
305
306        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
307        let k_nearest = &distances[0..k_neighbors.min(distances.len())];
308
309        // Select a random neighbor from the k nearest
310        let neighbor_idx = k_nearest[rng.random_range(0..k_nearest.len())].0;
311        let neighbor_sample = data.row(neighbor_idx);
312
313        // Generate synthetic sample by interpolation
314        let alpha = rng.random_range(0.0..1.0);
315        for (j, synthetic_feature) in synthetic_data.row_mut(i).iter_mut().enumerate() {
316            *synthetic_feature = base_sample[j] + alpha * (neighbor_sample[j] - base_sample[j]);
317        }
318    }
319
320    Ok((synthetic_data, synthetic_targets))
321}
322
323/// Creates a balanced dataset using the specified balancing strategy
324///
325/// Automatically balances the dataset by applying oversampling, undersampling,
326/// or synthetic sample generation based on the specified strategy.
327///
328/// # Arguments
329///
330/// * `data` - Feature matrix (n_samples, n_features)
331/// * `targets` - Target values for each sample
332/// * `strategy` - Balancing strategy to use
333/// * `random_seed` - Optional random seed for reproducible balancing
334///
335/// # Returns
336///
337/// A tuple containing the balanced (data, targets) arrays
338///
339/// # Examples
340///
341/// ```rust
342/// use ndarray::{Array1, Array2};
343/// use scirs2_datasets::utils::{create_balanced_dataset, BalancingStrategy};
344///
345/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
346/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
347/// let (balanced_data, balanced_targets) = create_balanced_dataset(&data, &targets, BalancingStrategy::RandomOversample, Some(42)).unwrap();
348/// ```
349pub fn create_balanced_dataset(
350    data: &Array2<f64>,
351    targets: &Array1<f64>,
352    strategy: BalancingStrategy,
353    random_seed: Option<u64>,
354) -> Result<(Array2<f64>, Array1<f64>)> {
355    match strategy {
356        BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
357        BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
358        BalancingStrategy::SMOTE { k_neighbors } => {
359            // Apply SMOTE to minority classes
360            let mut class_counts: HashMap<i64, usize> = HashMap::new();
361            for &target in targets.iter() {
362                let class = target.round() as i64;
363                *class_counts.entry(class).or_default() += 1;
364            }
365
366            let max_count = *class_counts.values().max().unwrap();
367            let mut combined_data = data.clone();
368            let mut combined_targets = targets.clone();
369
370            for (&class, &count) in &class_counts {
371                if count < max_count {
372                    let samples_needed = max_count - count;
373                    let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
374                        data,
375                        targets,
376                        class as f64,
377                        samples_needed,
378                        k_neighbors,
379                        random_seed,
380                    )?;
381
382                    // Concatenate with existing data
383                    combined_data = ndarray::concatenate(
384                        ndarray::Axis(0),
385                        &[combined_data.view(), synthetic_data.view()],
386                    )
387                    .map_err(|_| {
388                        DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
389                    })?;
390
391                    combined_targets = ndarray::concatenate(
392                        ndarray::Axis(0),
393                        &[combined_targets.view(), synthetic_targets.view()],
394                    )
395                    .map_err(|_| {
396                        DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
397                    })?;
398                }
399            }
400
401            Ok((combined_data, combined_targets))
402        }
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_oversample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4

        let (balanced_data, balanced_targets) =
            random_oversample(&data, &targets, Some(42)).unwrap();

        // Check that we now have equal number of each class
        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, 4); // Should be oversampled to match majority class
        assert_eq!(class_1_count, 4);

        // Check that total samples increased
        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);

        // Check that data dimensions are preserved
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_oversample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);

        // Mismatched data and targets
        assert!(random_oversample(&data, &targets, None).is_err());

        // Empty data
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_random_undersample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4

        let (balanced_data, balanced_targets) =
            random_undersample(&data, &targets, Some(42)).unwrap();

        // Check that we now have equal number of each class (minimum)
        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, 2); // Should match minority class
        assert_eq!(class_1_count, 2); // Should be undersampled to match minority class

        // Check that total samples decreased
        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);

        // Check that data dimensions are preserved
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_undersample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);

        // Mismatched data and targets
        assert!(random_undersample(&data, &targets, None).is_err());

        // Empty data
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_generate_synthetic_samples() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();

        // Check that we generated the correct number of synthetic samples
        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);

        // Check that all synthetic targets are the correct class
        for &target in synthetic_targets.iter() {
            assert_eq!(target, 0.0);
        }

        // Check that data dimensions are preserved
        assert_eq!(synthetic_data.ncols(), 2);

        // Interpolations between class-0 points (1,1), (2,2), (1.5,1.5) stay in [1, 2]
        for i in 0..synthetic_data.nrows() {
            for j in 0..synthetic_data.ncols() {
                let value = synthetic_data[[i, j]];
                assert!((1.0..=2.0).contains(&value));
            }
        }
    }

    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        // Mismatched data and targets
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());

        // Zero synthetic samples
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());

        // Zero neighbors
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());

        // Too few samples of target class (only 1 sample of class 1.0)
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());

        // k_neighbors >= number of samples in class
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        // Check that classes are balanced
        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, class_1_count);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        // Check that classes are balanced
        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, class_1_count);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .unwrap();
        // Already balanced: SMOTE must leave the dataset unchanged.
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        // Check that classes remain balanced and nothing was synthesized
        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, class_1_count);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
        assert_eq!(balanced_data, data);
        assert_eq!(balanced_targets, targets);
    }

    #[test]
    fn test_create_balanced_dataset_smote_imbalanced() {
        // Class 0 has 3 samples on the diagonal, class 1 has 4: SMOTE must add one class-0 row.
        let data = Array2::from_shape_vec(
            (7, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);
        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, 4);
        assert_eq!(class_1_count, 4);

        // Original rows are kept, in order, ahead of the synthetic row
        for i in 0..data.nrows() {
            assert_eq!(balanced_data.row(i), data.row(i));
        }

        // The synthetic row interpolates class-0 points, so it lies on the diagonal in [1, 2]
        let synthetic = balanced_data.row(7);
        assert_eq!(balanced_targets[7], 0.0);
        assert!((1.0..=2.0).contains(&synthetic[0]));
        assert_eq!(synthetic[0], synthetic[1]);
    }

    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        // Test with 3 classes of different sizes
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect()).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
        // Class distribution: 0 (2 samples), 1 (4 samples), 2 (3 samples)

        // Test oversampling
        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        let over_class_0_count = over_targets.iter().filter(|&&x| x == 0.0).count();
        let over_class_1_count = over_targets.iter().filter(|&&x| x == 1.0).count();
        let over_class_2_count = over_targets.iter().filter(|&&x| x == 2.0).count();

        // All classes should have 4 samples (majority class size)
        assert_eq!(over_class_0_count, 4);
        assert_eq!(over_class_1_count, 4);
        assert_eq!(over_class_2_count, 4);

        // Test undersampling
        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        let under_class_0_count = under_targets.iter().filter(|&&x| x == 0.0).count();
        let under_class_1_count = under_targets.iter().filter(|&&x| x == 1.0).count();
        let under_class_2_count = under_targets.iter().filter(|&&x| x == 2.0).count();

        // All classes should have 2 samples (minority class size)
        assert_eq!(under_class_0_count, 2);
        assert_eq!(under_class_1_count, 2);
        assert_eq!(under_class_2_count, 2);
    }
}