scirs2_datasets/utils/
balancing.rs

1//! Data balancing utilities for handling imbalanced datasets
2//!
3//! This module provides various strategies for balancing datasets to handle
4//! class imbalance problems in machine learning. It includes random oversampling,
5//! random undersampling, and SMOTE-like synthetic sample generation.
6
7use crate::error::{DatasetsError, Result};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::Uniform;
14use std::collections::HashMap;
15
/// Balancing strategies for handling imbalanced datasets
///
/// A value of this enum is passed to [`create_balanced_dataset`], which
/// dispatches to the matching resampling routine.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub enum BalancingStrategy {
    /// Random oversampling - duplicates minority class samples until every
    /// class is as large as the largest class
    RandomOversample,
    /// Random undersampling - removes majority class samples until every
    /// class is as small as the smallest class
    RandomUndersample,
    /// SMOTE (Synthetic Minority Oversampling Technique) with specified k_neighbors
    ///
    /// Minority classes are topped up with samples interpolated between an
    /// existing sample and one of its nearest same-class neighbors.
    SMOTE {
        /// Number of nearest neighbors to consider for synthetic sample generation.
        /// Must be greater than zero and smaller than the number of samples in
        /// every class that needs synthetic samples.
        k_neighbors: usize,
    },
}
29
30/// Performs random oversampling to balance class distribution
31///
32/// Duplicates samples from minority classes to match the majority class size.
33/// This is useful for handling imbalanced datasets in classification problems.
34///
35/// # Arguments
36///
37/// * `data` - Feature matrix (n_samples, n_features)
38/// * `targets` - Target values for each sample
39/// * `random_seed` - Optional random seed for reproducible sampling
40///
41/// # Returns
42///
43/// A tuple containing the resampled (data, targets) arrays
44///
45/// # Examples
46///
47/// ```rust
48/// use scirs2_core::ndarray::{Array1, Array2};
49/// use scirs2_datasets::utils::random_oversample;
50///
51/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
52/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
53/// let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42)).unwrap();
54/// // Now both classes have 4 samples each
55/// ```
56#[allow(dead_code)]
57pub fn random_oversample(
58    data: &Array2<f64>,
59    targets: &Array1<f64>,
60    random_seed: Option<u64>,
61) -> Result<(Array2<f64>, Array1<f64>)> {
62    if data.nrows() != targets.len() {
63        return Err(DatasetsError::InvalidFormat(
64            "Data rows and targets length must match".to_string(),
65        ));
66    }
67
68    if data.is_empty() || targets.is_empty() {
69        return Err(DatasetsError::InvalidFormat(
70            "Data and targets cannot be empty".to_string(),
71        ));
72    }
73
74    // Group indices by class
75    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
76    for (i, &target) in targets.iter().enumerate() {
77        let class = target.round() as i64;
78        class_indices.entry(class).or_default().push(i);
79    }
80
81    // Find the majority class size
82    let max_class_size = class_indices.values().map(|v| v.len()).max().unwrap();
83
84    let mut rng = match random_seed {
85        Some(_seed) => StdRng::seed_from_u64(_seed),
86        None => {
87            let mut r = thread_rng();
88            StdRng::seed_from_u64(r.next_u64())
89        }
90    };
91
92    // Collect all resampled indices
93    let mut resampled_indices = Vec::new();
94
95    for (_, indices) in class_indices {
96        let class_size = indices.len();
97
98        // Add all original samples
99        resampled_indices.extend(&indices);
100
101        // Oversample if this class is smaller than the majority class
102        if class_size < max_class_size {
103            let samples_needed = max_class_size - class_size;
104            for _ in 0..samples_needed {
105                let random_idx = rng.sample(Uniform::new(0, class_size).unwrap());
106                resampled_indices.push(indices[random_idx]);
107            }
108        }
109    }
110
111    // Create resampled data and targets
112    let resampled_data = data.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
113    let resampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
114
115    Ok((resampled_data, resampled_targets))
116}
117
118/// Performs random undersampling to balance class distribution
119///
120/// Randomly removes samples from majority classes to match the minority class size.
121/// This reduces the overall dataset size but maintains balance.
122///
123/// # Arguments
124///
125/// * `data` - Feature matrix (n_samples, n_features)
126/// * `targets` - Target values for each sample
127/// * `random_seed` - Optional random seed for reproducible sampling
128///
129/// # Returns
130///
131/// A tuple containing the undersampled (data, targets) arrays
132///
133/// # Examples
134///
135/// ```rust
136/// use scirs2_core::ndarray::{Array1, Array2};
137/// use scirs2_datasets::utils::random_undersample;
138///
139/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
140/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
141/// let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42)).unwrap();
142/// // Now both classes have 2 samples each
143/// ```
144#[allow(dead_code)]
145pub fn random_undersample(
146    data: &Array2<f64>,
147    targets: &Array1<f64>,
148    random_seed: Option<u64>,
149) -> Result<(Array2<f64>, Array1<f64>)> {
150    if data.nrows() != targets.len() {
151        return Err(DatasetsError::InvalidFormat(
152            "Data rows and targets length must match".to_string(),
153        ));
154    }
155
156    if data.is_empty() || targets.is_empty() {
157        return Err(DatasetsError::InvalidFormat(
158            "Data and targets cannot be empty".to_string(),
159        ));
160    }
161
162    // Group indices by class
163    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
164    for (i, &target) in targets.iter().enumerate() {
165        let class = target.round() as i64;
166        class_indices.entry(class).or_default().push(i);
167    }
168
169    // Find the minority class size
170    let min_class_size = class_indices.values().map(|v| v.len()).min().unwrap();
171
172    let mut rng = match random_seed {
173        Some(_seed) => StdRng::seed_from_u64(_seed),
174        None => {
175            let mut r = thread_rng();
176            StdRng::seed_from_u64(r.next_u64())
177        }
178    };
179
180    // Collect undersampled indices
181    let mut undersampled_indices = Vec::new();
182
183    for (_, mut indices) in class_indices {
184        if indices.len() > min_class_size {
185            // Randomly sample down to minority class size
186            indices.shuffle(&mut rng);
187            undersampled_indices.extend(&indices[0..min_class_size]);
188        } else {
189            // Use all samples if already at or below minority class size
190            undersampled_indices.extend(&indices);
191        }
192    }
193
194    // Create undersampled data and targets
195    let undersampled_data = data.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
196    let undersampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
197
198    Ok((undersampled_data, undersampled_targets))
199}
200
201/// Generates synthetic samples using SMOTE-like interpolation
202///
203/// Creates synthetic samples by interpolating between existing samples within each class.
204/// This is useful for oversampling minority classes without simple duplication.
205///
206/// # Arguments
207///
208/// * `data` - Feature matrix (n_samples, n_features)
209/// * `targets` - Target values for each sample
210/// * `target_class` - The class to generate synthetic samples for
211/// * `n_synthetic` - Number of synthetic samples to generate
212/// * `k_neighbors` - Number of nearest neighbors to consider for interpolation
213/// * `random_seed` - Optional random seed for reproducible generation
214///
215/// # Returns
216///
217/// A tuple containing the synthetic (data, targets) arrays
218///
219/// # Examples
220///
221/// ```rust
222/// use scirs2_core::ndarray::{Array1, Array2};
223/// use scirs2_datasets::utils::generate_synthetic_samples;
224///
225/// let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
226/// let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
227/// let (synthetic_data, synthetic_targets) = generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();
228/// assert_eq!(synthetic_data.nrows(), 2);
229/// assert_eq!(synthetic_targets.len(), 2);
230/// ```
231#[allow(dead_code)]
232pub fn generate_synthetic_samples(
233    data: &Array2<f64>,
234    targets: &Array1<f64>,
235    target_class: f64,
236    n_synthetic: usize,
237    k_neighbors: usize,
238    random_seed: Option<u64>,
239) -> Result<(Array2<f64>, Array1<f64>)> {
240    if data.nrows() != targets.len() {
241        return Err(DatasetsError::InvalidFormat(
242            "Data rows and targets length must match".to_string(),
243        ));
244    }
245
246    if n_synthetic == 0 {
247        return Err(DatasetsError::InvalidFormat(
248            "Number of _synthetic samples must be > 0".to_string(),
249        ));
250    }
251
252    if k_neighbors == 0 {
253        return Err(DatasetsError::InvalidFormat(
254            "Number of _neighbors must be > 0".to_string(),
255        ));
256    }
257
258    // Find samples belonging to the target _class
259    let class_indices: Vec<usize> = targets
260        .iter()
261        .enumerate()
262        .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
263        .map(|(i, _)| i)
264        .collect();
265
266    if class_indices.len() < 2 {
267        return Err(DatasetsError::InvalidFormat(
268            "Need at least 2 samples of the target _class for _synthetic generation".to_string(),
269        ));
270    }
271
272    if k_neighbors >= class_indices.len() {
273        return Err(DatasetsError::InvalidFormat(
274            "k_neighbors must be less than the number of samples in the target _class".to_string(),
275        ));
276    }
277
278    let mut rng = match random_seed {
279        Some(_seed) => StdRng::seed_from_u64(_seed),
280        None => {
281            let mut r = thread_rng();
282            StdRng::seed_from_u64(r.next_u64())
283        }
284    };
285
286    let n_features = data.ncols();
287    let mut synthetic_data = Array2::zeros((n_synthetic, n_features));
288    let synthetic_targets = Array1::from_elem(n_synthetic, target_class);
289
290    for i in 0..n_synthetic {
291        // Randomly select a sample from the target _class
292        let base_idx = class_indices[rng.sample(Uniform::new(0, class_indices.len()).unwrap())];
293        let base_sample = data.row(base_idx);
294
295        // Find k nearest _neighbors within the same _class
296        let mut distances: Vec<(usize, f64)> = class_indices
297            .iter()
298            .filter(|&&idx| idx != base_idx)
299            .map(|&idx| {
300                let neighbor = data.row(idx);
301                let distance: f64 = base_sample
302                    .iter()
303                    .zip(neighbor.iter())
304                    .map(|(&a, &b)| (a - b).powi(2))
305                    .sum::<f64>()
306                    .sqrt();
307                (idx, distance)
308            })
309            .collect();
310
311        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
312        let k_nearest = &distances[0..k_neighbors.min(distances.len())];
313
314        // Select a random neighbor from the k nearest
315        let neighbor_idx = k_nearest[rng.sample(Uniform::new(0, k_nearest.len()).unwrap())].0;
316        let neighbor_sample = data.row(neighbor_idx);
317
318        // Generate _synthetic sample by interpolation
319        let alpha = rng.gen_range(0.0..1.0);
320        for (j, synthetic_feature) in synthetic_data.row_mut(i).iter_mut().enumerate() {
321            *synthetic_feature = base_sample[j] + alpha * (neighbor_sample[j] - base_sample[j]);
322        }
323    }
324
325    Ok((synthetic_data, synthetic_targets))
326}
327
328/// Creates a balanced dataset using the specified balancing strategy
329///
330/// Automatically balances the dataset by applying oversampling, undersampling,
331/// or synthetic sample generation based on the specified strategy.
332///
333/// # Arguments
334///
335/// * `data` - Feature matrix (n_samples, n_features)
336/// * `targets` - Target values for each sample
337/// * `strategy` - Balancing strategy to use
338/// * `random_seed` - Optional random seed for reproducible balancing
339///
340/// # Returns
341///
342/// A tuple containing the balanced (data, targets) arrays
343///
344/// # Examples
345///
346/// ```rust
347/// use scirs2_core::ndarray::{Array1, Array2};
348/// use scirs2_datasets::utils::{create_balanced_dataset, BalancingStrategy};
349///
350/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
351/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
352/// let (balanced_data, balanced_targets) = create_balanced_dataset(&data, &targets, BalancingStrategy::RandomOversample, Some(42)).unwrap();
353/// ```
354#[allow(dead_code)]
355pub fn create_balanced_dataset(
356    data: &Array2<f64>,
357    targets: &Array1<f64>,
358    strategy: BalancingStrategy,
359    random_seed: Option<u64>,
360) -> Result<(Array2<f64>, Array1<f64>)> {
361    match strategy {
362        BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
363        BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
364        BalancingStrategy::SMOTE { k_neighbors } => {
365            // Apply SMOTE to minority classes
366            let mut class_counts: HashMap<i64, usize> = HashMap::new();
367            for &target in targets.iter() {
368                let class = target.round() as i64;
369                *class_counts.entry(class).or_default() += 1;
370            }
371
372            let max_count = *class_counts.values().max().unwrap();
373            let mut combined_data = data.clone();
374            let mut combined_targets = targets.clone();
375
376            for (&class, &count) in &class_counts {
377                if count < max_count {
378                    let samples_needed = max_count - count;
379                    let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
380                        data,
381                        targets,
382                        class as f64,
383                        samples_needed,
384                        k_neighbors,
385                        random_seed,
386                    )?;
387
388                    // Concatenate with existing data
389                    combined_data = scirs2_core::ndarray::concatenate(
390                        scirs2_core::ndarray::Axis(0),
391                        &[combined_data.view(), synthetic_data.view()],
392                    )
393                    .map_err(|_| {
394                        DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
395                    })?;
396
397                    combined_targets = scirs2_core::ndarray::concatenate(
398                        scirs2_core::ndarray::Axis(0),
399                        &[combined_targets.view(), synthetic_targets.view()],
400                    )
401                    .map_err(|_| {
402                        DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
403                    })?;
404                }
405            }
406
407            Ok((combined_data, combined_targets))
408        }
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Six samples with two features each: two of class 0 followed by four of class 1.
    fn imbalanced_two_class() -> (Array2<f64>, Array1<f64>) {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
        (data, targets)
    }

    /// Four samples with two features each: three of class 0 and one of class 1.
    fn synthetic_fixture() -> (Array2<f64>, Array1<f64>) {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
        (data, targets)
    }

    /// Counts the targets that are exactly equal to `class`.
    fn count_class(targets: &Array1<f64>, class: f64) -> usize {
        targets.iter().filter(|&&x| x == class).count()
    }

    #[test]
    fn test_random_oversample() {
        let (data, targets) = imbalanced_two_class();

        let (balanced_data, balanced_targets) =
            random_oversample(&data, &targets, Some(42)).unwrap();

        // Check that we now have equal number of each class
        assert_eq!(count_class(&balanced_targets, 0.0), 4); // Oversampled to match majority class
        assert_eq!(count_class(&balanced_targets, 1.0), 4);

        // Check that total samples increased
        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);

        // Check that data dimensions are preserved
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_oversample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);

        // Mismatched data and targets
        assert!(random_oversample(&data, &targets, None).is_err());

        // Empty data
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_random_undersample() {
        let (data, targets) = imbalanced_two_class();

        let (balanced_data, balanced_targets) =
            random_undersample(&data, &targets, Some(42)).unwrap();

        // Check that we now have equal number of each class (minimum)
        assert_eq!(count_class(&balanced_targets, 0.0), 2); // Should match minority class
        assert_eq!(count_class(&balanced_targets, 1.0), 2); // Undersampled to match minority class

        // Check that total samples decreased
        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);

        // Check that data dimensions are preserved
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_undersample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);

        // Mismatched data and targets
        assert!(random_undersample(&data, &targets, None).is_err());

        // Empty data
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_generate_synthetic_samples() {
        let (data, targets) = synthetic_fixture();

        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();

        // Check that we generated the correct number of synthetic samples
        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);

        // Check that all synthetic targets are the correct class
        for &target in synthetic_targets.iter() {
            assert_eq!(target, 0.0);
        }

        // Check that data dimensions are preserved
        assert_eq!(synthetic_data.ncols(), 2);

        // Check that synthetic samples are interpolations (should be within reasonable bounds)
        for &value in synthetic_data.iter() {
            assert!((0.5..=2.5).contains(&value)); // Should be within range of class 0 samples
        }
    }

    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let (data, targets) = synthetic_fixture();

        // Mismatched data and targets
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());

        // Zero synthetic samples
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());

        // Zero neighbors
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());

        // Too few samples of target class (only 1 sample of class 1.0)
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());

        // k_neighbors >= number of samples in class
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let (data, targets) = imbalanced_two_class();

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        // Check that classes are balanced
        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let (data, targets) = imbalanced_two_class();

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        // Check that classes are balanced
        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .unwrap();
        // Already balanced: SMOTE must leave the dataset untouched
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        // Check that classes remain balanced
        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote_imbalanced() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 3 vs 5

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        // Two synthetic samples of class 0 bring both classes to 5 samples
        assert_eq!(count_class(&balanced_targets, 0.0), 5);
        assert_eq!(count_class(&balanced_targets, 1.0), 5);
        assert_eq!(balanced_data.nrows(), 10);
        assert_eq!(balanced_targets.len(), 10);
        assert_eq!(balanced_data.ncols(), 2);

        // The original samples come first and are left untouched
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                assert_eq!(balanced_data[[i, j]], data[[i, j]]);
            }
            assert_eq!(balanced_targets[i], targets[i]);
        }

        // The appended samples are interpolations between class 0 samples
        for i in data.nrows()..balanced_data.nrows() {
            assert_eq!(balanced_targets[i], 0.0);
            for j in 0..balanced_data.ncols() {
                assert!((0.5..=2.5).contains(&balanced_data[[i, j]]));
            }
        }
    }

    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        // Test with 3 classes of different sizes
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect()).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
        // Class distribution: 0 (2 samples), 1 (4 samples), 2 (3 samples)

        // Test oversampling
        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        // All classes should have 4 samples (majority class size)
        for class in [0.0, 1.0, 2.0] {
            assert_eq!(count_class(&over_targets, class), 4);
        }

        // Test undersampling
        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        // All classes should have 2 samples (minority class size)
        for class in [0.0, 1.0, 2.0] {
            assert_eq!(count_class(&under_targets, class), 2);
        }
    }
}