scirs2_datasets/utils/
balancing.rs

//! Data balancing utilities for handling imbalanced datasets
//!
//! This module provides strategies for balancing datasets to address class
//! imbalance in machine learning: random oversampling, random undersampling,
//! and SMOTE-like synthetic sample generation.
6
7use crate::error::{DatasetsError, Result};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::Uniform;
14use std::collections::HashMap;
15
16/// Balancing strategies for handling imbalanced datasets
17#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
18pub enum BalancingStrategy {
19    /// Random oversampling - duplicates minority class samples
20    RandomOversample,
21    /// Random undersampling - removes majority class samples
22    RandomUndersample,
23    /// SMOTE (Synthetic Minority Oversampling Technique) with specified k_neighbors
24    SMOTE {
25        /// Number of nearest neighbors to consider for synthetic sample generation
26        k_neighbors: usize,
27    },
28}
29
30/// Performs random oversampling to balance class distribution
31///
32/// Duplicates samples from minority classes to match the majority class size.
33/// This is useful for handling imbalanced datasets in classification problems.
34///
35/// # Arguments
36///
37/// * `data` - Feature matrix (n_samples, n_features)
38/// * `targets` - Target values for each sample
39/// * `random_seed` - Optional random seed for reproducible sampling
40///
41/// # Returns
42///
43/// A tuple containing the resampled (data, targets) arrays
44///
45/// # Examples
46///
47/// ```rust
48/// use scirs2_core::ndarray::{Array1, Array2};
49/// use scirs2_datasets::utils::random_oversample;
50///
51/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("Operation failed");
52/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
53/// let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42)).expect("Operation failed");
54/// // Now both classes have 4 samples each
55/// ```
56#[allow(dead_code)]
57pub fn random_oversample(
58    data: &Array2<f64>,
59    targets: &Array1<f64>,
60    random_seed: Option<u64>,
61) -> Result<(Array2<f64>, Array1<f64>)> {
62    if data.nrows() != targets.len() {
63        return Err(DatasetsError::InvalidFormat(
64            "Data rows and targets length must match".to_string(),
65        ));
66    }
67
68    if data.is_empty() || targets.is_empty() {
69        return Err(DatasetsError::InvalidFormat(
70            "Data and targets cannot be empty".to_string(),
71        ));
72    }
73
74    // Group indices by class
75    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
76    for (i, &target) in targets.iter().enumerate() {
77        let class = target.round() as i64;
78        class_indices.entry(class).or_default().push(i);
79    }
80
81    // Find the majority class size
82    let max_class_size = class_indices
83        .values()
84        .map(|v| v.len())
85        .max()
86        .expect("Operation failed");
87
88    let mut rng = match random_seed {
89        Some(_seed) => StdRng::seed_from_u64(_seed),
90        None => {
91            let mut r = thread_rng();
92            StdRng::seed_from_u64(r.next_u64())
93        }
94    };
95
96    // Collect all resampled indices
97    let mut resampled_indices = Vec::new();
98
99    for (_, indices) in class_indices {
100        let class_size = indices.len();
101
102        // Add all original samples
103        resampled_indices.extend(&indices);
104
105        // Oversample if this class is smaller than the majority class
106        if class_size < max_class_size {
107            let samples_needed = max_class_size - class_size;
108            for _ in 0..samples_needed {
109                let random_idx = rng.sample(Uniform::new(0, class_size).expect("Operation failed"));
110                resampled_indices.push(indices[random_idx]);
111            }
112        }
113    }
114
115    // Create resampled data and targets
116    let resampled_data = data.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
117    let resampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
118
119    Ok((resampled_data, resampled_targets))
120}
121
122/// Performs random undersampling to balance class distribution
123///
124/// Randomly removes samples from majority classes to match the minority class size.
125/// This reduces the overall dataset size but maintains balance.
126///
127/// # Arguments
128///
129/// * `data` - Feature matrix (n_samples, n_features)
130/// * `targets` - Target values for each sample
131/// * `random_seed` - Optional random seed for reproducible sampling
132///
133/// # Returns
134///
135/// A tuple containing the undersampled (data, targets) arrays
136///
137/// # Examples
138///
139/// ```rust
140/// use scirs2_core::ndarray::{Array1, Array2};
141/// use scirs2_datasets::utils::random_undersample;
142///
143/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("Operation failed");
144/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Imbalanced: 2 vs 4
145/// let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42)).expect("Operation failed");
146/// // Now both classes have 2 samples each
147/// ```
148#[allow(dead_code)]
149pub fn random_undersample(
150    data: &Array2<f64>,
151    targets: &Array1<f64>,
152    random_seed: Option<u64>,
153) -> Result<(Array2<f64>, Array1<f64>)> {
154    if data.nrows() != targets.len() {
155        return Err(DatasetsError::InvalidFormat(
156            "Data rows and targets length must match".to_string(),
157        ));
158    }
159
160    if data.is_empty() || targets.is_empty() {
161        return Err(DatasetsError::InvalidFormat(
162            "Data and targets cannot be empty".to_string(),
163        ));
164    }
165
166    // Group indices by class
167    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
168    for (i, &target) in targets.iter().enumerate() {
169        let class = target.round() as i64;
170        class_indices.entry(class).or_default().push(i);
171    }
172
173    // Find the minority class size
174    let min_class_size = class_indices
175        .values()
176        .map(|v| v.len())
177        .min()
178        .expect("Operation failed");
179
180    let mut rng = match random_seed {
181        Some(_seed) => StdRng::seed_from_u64(_seed),
182        None => {
183            let mut r = thread_rng();
184            StdRng::seed_from_u64(r.next_u64())
185        }
186    };
187
188    // Collect undersampled indices
189    let mut undersampled_indices = Vec::new();
190
191    for (_, mut indices) in class_indices {
192        if indices.len() > min_class_size {
193            // Randomly sample down to minority class size
194            indices.shuffle(&mut rng);
195            undersampled_indices.extend(&indices[0..min_class_size]);
196        } else {
197            // Use all samples if already at or below minority class size
198            undersampled_indices.extend(&indices);
199        }
200    }
201
202    // Create undersampled data and targets
203    let undersampled_data = data.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
204    let undersampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
205
206    Ok((undersampled_data, undersampled_targets))
207}
208
209/// Generates synthetic samples using SMOTE-like interpolation
210///
211/// Creates synthetic samples by interpolating between existing samples within each class.
212/// This is useful for oversampling minority classes without simple duplication.
213///
214/// # Arguments
215///
216/// * `data` - Feature matrix (n_samples, n_features)
217/// * `targets` - Target values for each sample
218/// * `target_class` - The class to generate synthetic samples for
219/// * `n_synthetic` - Number of synthetic samples to generate
220/// * `k_neighbors` - Number of nearest neighbors to consider for interpolation
221/// * `random_seed` - Optional random seed for reproducible generation
222///
223/// # Returns
224///
225/// A tuple containing the synthetic (data, targets) arrays
226///
227/// # Examples
228///
229/// ```rust
230/// use scirs2_core::ndarray::{Array1, Array2};
231/// use scirs2_datasets::utils::generate_synthetic_samples;
232///
233/// let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).expect("Operation failed");
234/// let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
235/// let (synthetic_data, synthetic_targets) = generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).expect("Operation failed");
236/// assert_eq!(synthetic_data.nrows(), 2);
237/// assert_eq!(synthetic_targets.len(), 2);
238/// ```
239#[allow(dead_code)]
240pub fn generate_synthetic_samples(
241    data: &Array2<f64>,
242    targets: &Array1<f64>,
243    target_class: f64,
244    n_synthetic: usize,
245    k_neighbors: usize,
246    random_seed: Option<u64>,
247) -> Result<(Array2<f64>, Array1<f64>)> {
248    if data.nrows() != targets.len() {
249        return Err(DatasetsError::InvalidFormat(
250            "Data rows and targets length must match".to_string(),
251        ));
252    }
253
254    if n_synthetic == 0 {
255        return Err(DatasetsError::InvalidFormat(
256            "Number of _synthetic samples must be > 0".to_string(),
257        ));
258    }
259
260    if k_neighbors == 0 {
261        return Err(DatasetsError::InvalidFormat(
262            "Number of _neighbors must be > 0".to_string(),
263        ));
264    }
265
266    // Find samples belonging to the target _class
267    let class_indices: Vec<usize> = targets
268        .iter()
269        .enumerate()
270        .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
271        .map(|(i, _)| i)
272        .collect();
273
274    if class_indices.len() < 2 {
275        return Err(DatasetsError::InvalidFormat(
276            "Need at least 2 samples of the target _class for _synthetic generation".to_string(),
277        ));
278    }
279
280    if k_neighbors >= class_indices.len() {
281        return Err(DatasetsError::InvalidFormat(
282            "k_neighbors must be less than the number of samples in the target _class".to_string(),
283        ));
284    }
285
286    let mut rng = match random_seed {
287        Some(_seed) => StdRng::seed_from_u64(_seed),
288        None => {
289            let mut r = thread_rng();
290            StdRng::seed_from_u64(r.next_u64())
291        }
292    };
293
294    let n_features = data.ncols();
295    let mut synthetic_data = Array2::zeros((n_synthetic, n_features));
296    let synthetic_targets = Array1::from_elem(n_synthetic, target_class);
297
298    for i in 0..n_synthetic {
299        // Randomly select a sample from the target _class
300        let base_idx = class_indices
301            [rng.sample(Uniform::new(0, class_indices.len()).expect("Operation failed"))];
302        let base_sample = data.row(base_idx);
303
304        // Find k nearest _neighbors within the same _class
305        let mut distances: Vec<(usize, f64)> = class_indices
306            .iter()
307            .filter(|&&idx| idx != base_idx)
308            .map(|&idx| {
309                let neighbor = data.row(idx);
310                let distance: f64 = base_sample
311                    .iter()
312                    .zip(neighbor.iter())
313                    .map(|(&a, &b)| (a - b).powi(2))
314                    .sum::<f64>()
315                    .sqrt();
316                (idx, distance)
317            })
318            .collect();
319
320        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
321        let k_nearest = &distances[0..k_neighbors.min(distances.len())];
322
323        // Select a random neighbor from the k nearest
324        let neighbor_idx =
325            k_nearest[rng.sample(Uniform::new(0, k_nearest.len()).expect("Operation failed"))].0;
326        let neighbor_sample = data.row(neighbor_idx);
327
328        // Generate _synthetic sample by interpolation
329        let alpha = rng.gen_range(0.0..1.0);
330        for (j, synthetic_feature) in synthetic_data.row_mut(i).iter_mut().enumerate() {
331            *synthetic_feature = base_sample[j] + alpha * (neighbor_sample[j] - base_sample[j]);
332        }
333    }
334
335    Ok((synthetic_data, synthetic_targets))
336}
337
338/// Creates a balanced dataset using the specified balancing strategy
339///
340/// Automatically balances the dataset by applying oversampling, undersampling,
341/// or synthetic sample generation based on the specified strategy.
342///
343/// # Arguments
344///
345/// * `data` - Feature matrix (n_samples, n_features)
346/// * `targets` - Target values for each sample
347/// * `strategy` - Balancing strategy to use
348/// * `random_seed` - Optional random seed for reproducible balancing
349///
350/// # Returns
351///
352/// A tuple containing the balanced (data, targets) arrays
353///
354/// # Examples
355///
356/// ```rust
357/// use scirs2_core::ndarray::{Array1, Array2};
358/// use scirs2_datasets::utils::{create_balanced_dataset, BalancingStrategy};
359///
360/// let data = Array2::from_shape_vec((6, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("Operation failed");
361/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
362/// let (balanced_data, balanced_targets) = create_balanced_dataset(&data, &targets, BalancingStrategy::RandomOversample, Some(42)).expect("Operation failed");
363/// ```
364#[allow(dead_code)]
365pub fn create_balanced_dataset(
366    data: &Array2<f64>,
367    targets: &Array1<f64>,
368    strategy: BalancingStrategy,
369    random_seed: Option<u64>,
370) -> Result<(Array2<f64>, Array1<f64>)> {
371    match strategy {
372        BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
373        BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
374        BalancingStrategy::SMOTE { k_neighbors } => {
375            // Apply SMOTE to minority classes
376            let mut class_counts: HashMap<i64, usize> = HashMap::new();
377            for &target in targets.iter() {
378                let class = target.round() as i64;
379                *class_counts.entry(class).or_default() += 1;
380            }
381
382            let max_count = *class_counts.values().max().expect("Operation failed");
383            let mut combined_data = data.clone();
384            let mut combined_targets = targets.clone();
385
386            for (&class, &count) in &class_counts {
387                if count < max_count {
388                    let samples_needed = max_count - count;
389                    let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
390                        data,
391                        targets,
392                        class as f64,
393                        samples_needed,
394                        k_neighbors,
395                        random_seed,
396                    )?;
397
398                    // Concatenate with existing data
399                    combined_data = scirs2_core::ndarray::concatenate(
400                        scirs2_core::ndarray::Axis(0),
401                        &[combined_data.view(), synthetic_data.view()],
402                    )
403                    .map_err(|_| {
404                        DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
405                    })?;
406
407                    combined_targets = scirs2_core::ndarray::concatenate(
408                        scirs2_core::ndarray::Axis(0),
409                        &[combined_targets.view(), synthetic_targets.view()],
410                    )
411                    .map_err(|_| {
412                        DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
413                    })?;
414                }
415            }
416
417            Ok((combined_data, combined_targets))
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Six two-feature samples with an imbalanced 2 (class 0) vs 4 (class 1) split.
    fn imbalanced_fixture() -> (Array2<f64>, Array1<f64>) {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("Test: 6x2 fixture must have 12 values");
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
        (data, targets)
    }

    /// Number of targets exactly equal to `label`.
    fn count_label(targets: &Array1<f64>, label: f64) -> usize {
        targets.iter().filter(|&&x| x == label).count()
    }

    #[test]
    fn test_random_oversample() {
        let (data, targets) = imbalanced_fixture();

        let (balanced_data, balanced_targets) =
            random_oversample(&data, &targets, Some(42)).expect("Test: oversampling failed");

        // The minority class is oversampled up to the majority class size.
        assert_eq!(count_label(&balanced_targets, 0.0), 4);
        assert_eq!(count_label(&balanced_targets, 1.0), 4);

        // Total samples increased and the feature dimension is preserved.
        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_oversample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Test: 3x2 fixture must have 6 values");
        let targets = Array1::from(vec![0.0, 1.0]);

        // Mismatched data and targets
        assert!(random_oversample(&data, &targets, None).is_err());

        // Empty data
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_random_undersample() {
        let (data, targets) = imbalanced_fixture();

        let (balanced_data, balanced_targets) =
            random_undersample(&data, &targets, Some(42)).expect("Test: undersampling failed");

        // The majority class is undersampled down to the minority class size.
        assert_eq!(count_label(&balanced_targets, 0.0), 2);
        assert_eq!(count_label(&balanced_targets, 1.0), 2);

        // Total samples decreased and the feature dimension is preserved.
        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_undersample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Test: 3x2 fixture must have 6 values");
        let targets = Array1::from(vec![0.0, 1.0]);

        // Mismatched data and targets
        assert!(random_undersample(&data, &targets, None).is_err());

        // Empty data
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_generate_synthetic_samples() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5])
            .expect("Test: 4x2 fixture must have 8 values");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42))
                .expect("Test: synthetic generation failed");

        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);
        assert_eq!(synthetic_data.ncols(), 2);
        assert!(synthetic_targets.iter().all(|&target| target == 0.0));

        // Interpolations between class-0 samples (1.0, 1.5 and 2.0 per feature) must
        // stay inside [1.0, 2.0].
        for &value in synthetic_data.iter() {
            assert!((1.0..=2.0).contains(&value));
        }
    }

    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5])
            .expect("Test: 4x2 fixture must have 8 values");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        // Mismatched data and targets
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());

        // Zero synthetic samples
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());

        // Zero neighbors
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());

        // Too few samples of target class (only 1 sample of class 1.0)
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());

        // k_neighbors >= number of samples in class
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let (data, targets) = imbalanced_fixture();

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .expect("Test: oversample balancing failed");

        assert_eq!(
            count_label(&balanced_targets, 0.0),
            count_label(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let (data, targets) = imbalanced_fixture();

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .expect("Test: undersample balancing failed");

        assert_eq!(
            count_label(&balanced_targets, 0.0),
            count_label(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .expect("Test: 8x2 fixture must have 16 values");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]); // Already balanced

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .expect("Test: SMOTE balancing failed");

        // Classes remain balanced and nothing is added.
        assert_eq!(count_label(&balanced_targets, 0.0), 4);
        assert_eq!(count_label(&balanced_targets, 1.0), 4);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote_imbalanced() {
        // Class 0 has 3 samples, class 1 has 5.
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0,
            ],
        )
        .expect("Test: 8x2 fixture must have 16 values");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .expect("Test: SMOTE balancing failed");

        assert_eq!(count_label(&balanced_targets, 0.0), 5);
        assert_eq!(count_label(&balanced_targets, 1.0), 5);
        assert_eq!(balanced_data.nrows(), 10);
        assert_eq!(balanced_targets.len(), 10);

        // The originals come first. The two appended synthetic class-0 samples must
        // lie between existing class-0 samples.
        for row in balanced_data.rows().into_iter().skip(8) {
            for &value in row.iter() {
                assert!((1.0..=2.0).contains(&value));
            }
        }
    }

    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        // Class distribution: 0 (2 samples), 1 (4 samples), 2 (3 samples)
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect())
            .expect("Test: 9x2 fixture must have 18 values");
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        // Oversampling: every class grows to the majority size (4).
        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .expect("Test: multi-class oversampling failed");

        for label in [0.0, 1.0, 2.0] {
            assert_eq!(count_label(&over_targets, label), 4);
        }

        // Undersampling: every class shrinks to the minority size (2).
        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .expect("Test: multi-class undersampling failed");

        for label in [0.0, 1.0, 2.0] {
            assert_eq!(count_label(&under_targets, label), 2);
        }
    }
}