1use crate::error::{DatasetsError, Result};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::Uniform;
14use std::collections::HashMap;
15
16#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
18pub enum BalancingStrategy {
19 RandomOversample,
21 RandomUndersample,
23 SMOTE {
25 k_neighbors: usize,
27 },
28}
29
30#[allow(dead_code)]
57pub fn random_oversample(
58 data: &Array2<f64>,
59 targets: &Array1<f64>,
60 random_seed: Option<u64>,
61) -> Result<(Array2<f64>, Array1<f64>)> {
62 if data.nrows() != targets.len() {
63 return Err(DatasetsError::InvalidFormat(
64 "Data rows and targets length must match".to_string(),
65 ));
66 }
67
68 if data.is_empty() || targets.is_empty() {
69 return Err(DatasetsError::InvalidFormat(
70 "Data and targets cannot be empty".to_string(),
71 ));
72 }
73
74 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
76 for (i, &target) in targets.iter().enumerate() {
77 let class = target.round() as i64;
78 class_indices.entry(class).or_default().push(i);
79 }
80
81 let max_class_size = class_indices.values().map(|v| v.len()).max().unwrap();
83
84 let mut rng = match random_seed {
85 Some(_seed) => StdRng::seed_from_u64(_seed),
86 None => {
87 let mut r = thread_rng();
88 StdRng::seed_from_u64(r.next_u64())
89 }
90 };
91
92 let mut resampled_indices = Vec::new();
94
95 for (_, indices) in class_indices {
96 let class_size = indices.len();
97
98 resampled_indices.extend(&indices);
100
101 if class_size < max_class_size {
103 let samples_needed = max_class_size - class_size;
104 for _ in 0..samples_needed {
105 let random_idx = rng.sample(Uniform::new(0, class_size).unwrap());
106 resampled_indices.push(indices[random_idx]);
107 }
108 }
109 }
110
111 let resampled_data = data.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
113 let resampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
114
115 Ok((resampled_data, resampled_targets))
116}
117
118#[allow(dead_code)]
145pub fn random_undersample(
146 data: &Array2<f64>,
147 targets: &Array1<f64>,
148 random_seed: Option<u64>,
149) -> Result<(Array2<f64>, Array1<f64>)> {
150 if data.nrows() != targets.len() {
151 return Err(DatasetsError::InvalidFormat(
152 "Data rows and targets length must match".to_string(),
153 ));
154 }
155
156 if data.is_empty() || targets.is_empty() {
157 return Err(DatasetsError::InvalidFormat(
158 "Data and targets cannot be empty".to_string(),
159 ));
160 }
161
162 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
164 for (i, &target) in targets.iter().enumerate() {
165 let class = target.round() as i64;
166 class_indices.entry(class).or_default().push(i);
167 }
168
169 let min_class_size = class_indices.values().map(|v| v.len()).min().unwrap();
171
172 let mut rng = match random_seed {
173 Some(_seed) => StdRng::seed_from_u64(_seed),
174 None => {
175 let mut r = thread_rng();
176 StdRng::seed_from_u64(r.next_u64())
177 }
178 };
179
180 let mut undersampled_indices = Vec::new();
182
183 for (_, mut indices) in class_indices {
184 if indices.len() > min_class_size {
185 indices.shuffle(&mut rng);
187 undersampled_indices.extend(&indices[0..min_class_size]);
188 } else {
189 undersampled_indices.extend(&indices);
191 }
192 }
193
194 let undersampled_data = data.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
196 let undersampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
197
198 Ok((undersampled_data, undersampled_targets))
199}
200
201#[allow(dead_code)]
232pub fn generate_synthetic_samples(
233 data: &Array2<f64>,
234 targets: &Array1<f64>,
235 target_class: f64,
236 n_synthetic: usize,
237 k_neighbors: usize,
238 random_seed: Option<u64>,
239) -> Result<(Array2<f64>, Array1<f64>)> {
240 if data.nrows() != targets.len() {
241 return Err(DatasetsError::InvalidFormat(
242 "Data rows and targets length must match".to_string(),
243 ));
244 }
245
246 if n_synthetic == 0 {
247 return Err(DatasetsError::InvalidFormat(
248 "Number of _synthetic samples must be > 0".to_string(),
249 ));
250 }
251
252 if k_neighbors == 0 {
253 return Err(DatasetsError::InvalidFormat(
254 "Number of _neighbors must be > 0".to_string(),
255 ));
256 }
257
258 let class_indices: Vec<usize> = targets
260 .iter()
261 .enumerate()
262 .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
263 .map(|(i, _)| i)
264 .collect();
265
266 if class_indices.len() < 2 {
267 return Err(DatasetsError::InvalidFormat(
268 "Need at least 2 samples of the target _class for _synthetic generation".to_string(),
269 ));
270 }
271
272 if k_neighbors >= class_indices.len() {
273 return Err(DatasetsError::InvalidFormat(
274 "k_neighbors must be less than the number of samples in the target _class".to_string(),
275 ));
276 }
277
278 let mut rng = match random_seed {
279 Some(_seed) => StdRng::seed_from_u64(_seed),
280 None => {
281 let mut r = thread_rng();
282 StdRng::seed_from_u64(r.next_u64())
283 }
284 };
285
286 let n_features = data.ncols();
287 let mut synthetic_data = Array2::zeros((n_synthetic, n_features));
288 let synthetic_targets = Array1::from_elem(n_synthetic, target_class);
289
290 for i in 0..n_synthetic {
291 let base_idx = class_indices[rng.sample(Uniform::new(0, class_indices.len()).unwrap())];
293 let base_sample = data.row(base_idx);
294
295 let mut distances: Vec<(usize, f64)> = class_indices
297 .iter()
298 .filter(|&&idx| idx != base_idx)
299 .map(|&idx| {
300 let neighbor = data.row(idx);
301 let distance: f64 = base_sample
302 .iter()
303 .zip(neighbor.iter())
304 .map(|(&a, &b)| (a - b).powi(2))
305 .sum::<f64>()
306 .sqrt();
307 (idx, distance)
308 })
309 .collect();
310
311 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
312 let k_nearest = &distances[0..k_neighbors.min(distances.len())];
313
314 let neighbor_idx = k_nearest[rng.sample(Uniform::new(0, k_nearest.len()).unwrap())].0;
316 let neighbor_sample = data.row(neighbor_idx);
317
318 let alpha = rng.gen_range(0.0..1.0);
320 for (j, synthetic_feature) in synthetic_data.row_mut(i).iter_mut().enumerate() {
321 *synthetic_feature = base_sample[j] + alpha * (neighbor_sample[j] - base_sample[j]);
322 }
323 }
324
325 Ok((synthetic_data, synthetic_targets))
326}
327
328#[allow(dead_code)]
355pub fn create_balanced_dataset(
356 data: &Array2<f64>,
357 targets: &Array1<f64>,
358 strategy: BalancingStrategy,
359 random_seed: Option<u64>,
360) -> Result<(Array2<f64>, Array1<f64>)> {
361 match strategy {
362 BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
363 BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
364 BalancingStrategy::SMOTE { k_neighbors } => {
365 let mut class_counts: HashMap<i64, usize> = HashMap::new();
367 for &target in targets.iter() {
368 let class = target.round() as i64;
369 *class_counts.entry(class).or_default() += 1;
370 }
371
372 let max_count = *class_counts.values().max().unwrap();
373 let mut combined_data = data.clone();
374 let mut combined_targets = targets.clone();
375
376 for (&class, &count) in &class_counts {
377 if count < max_count {
378 let samples_needed = max_count - count;
379 let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
380 data,
381 targets,
382 class as f64,
383 samples_needed,
384 k_neighbors,
385 random_seed,
386 )?;
387
388 combined_data = scirs2_core::ndarray::concatenate(
390 scirs2_core::ndarray::Axis(0),
391 &[combined_data.view(), synthetic_data.view()],
392 )
393 .map_err(|_| {
394 DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
395 })?;
396
397 combined_targets = scirs2_core::ndarray::concatenate(
398 scirs2_core::ndarray::Axis(0),
399 &[combined_targets.view(), synthetic_targets.view()],
400 )
401 .map_err(|_| {
402 DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
403 })?;
404 }
405 }
406
407 Ok((combined_data, combined_targets))
408 }
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Six two-feature rows: two of class 0 followed by four of class 1.
    fn two_vs_four_dataset() -> (Array2<f64>, Array1<f64>) {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
        (data, targets)
    }

    /// Number of entries of `targets` exactly equal to `class`.
    fn count_class(targets: &Array1<f64>, class: f64) -> usize {
        targets.iter().filter(|&&x| x == class).count()
    }

    #[test]
    fn test_random_oversample() {
        let (data, targets) = two_vs_four_dataset();
        let (balanced_data, balanced_targets) =
            random_oversample(&data, &targets, Some(42)).unwrap();

        assert_eq!(count_class(&balanced_targets, 0.0), 4);
        assert_eq!(count_class(&balanced_targets, 1.0), 4);
        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_oversample_keeps_rows_paired_with_targets() {
        let (data, targets) = two_vs_four_dataset();
        let (balanced_data, balanced_targets) =
            random_oversample(&data, &targets, Some(7)).unwrap();

        for (row, &target) in balanced_data.outer_iter().zip(balanced_targets.iter()) {
            let source = (0..data.nrows())
                .find(|&j| data.row(j) == row)
                .expect("every resampled row must come from the input");
            assert_eq!(targets[source], target);
        }
    }

    #[test]
    fn test_random_oversample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);
        assert!(random_oversample(&data, &targets, None).is_err());

        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_random_undersample() {
        let (data, targets) = two_vs_four_dataset();
        let (balanced_data, balanced_targets) =
            random_undersample(&data, &targets, Some(42)).unwrap();

        assert_eq!(count_class(&balanced_targets, 0.0), 2);
        assert_eq!(count_class(&balanced_targets, 1.0), 2);
        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_undersample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);
        assert!(random_undersample(&data, &targets, None).is_err());

        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_generate_synthetic_samples() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();

        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);
        assert_eq!(synthetic_data.ncols(), 2);
        assert!(synthetic_targets.iter().all(|&target| target == 0.0));

        // Interpolating between class-0 points keeps every feature inside their range.
        for &value in synthetic_data.iter() {
            assert!((0.5..=2.5).contains(&value));
        }
    }

    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        // Mismatched lengths.
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());
        // Zero synthetic samples requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());
        // Zero neighbours.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());
        // Class 1 has a single sample.
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());
        // k_neighbors equals the class size.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let (data, targets) = two_vs_four_dataset();
        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let (data, targets) = two_vs_four_dataset();
        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote_generates_minority_samples() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 3.0, 3.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        assert_eq!(balanced_data.nrows(), 10);
        assert_eq!(balanced_targets.len(), 10);
        assert_eq!(count_class(&balanced_targets, 0.0), 5);
        assert_eq!(count_class(&balanced_targets, 1.0), 5);

        // The original rows are kept, in order, as a prefix of the output.
        for i in 0..data.nrows() {
            assert_eq!(balanced_data.row(i), data.row(i));
            assert_eq!(balanced_targets[i], targets[i]);
        }

        // Synthetic class-1 rows interpolate between class-1 samples, so stay inside [6, 8].
        for i in data.nrows()..balanced_data.nrows() {
            assert_eq!(balanced_targets[i], 1.0);
            for &value in balanced_data.row(i).iter() {
                assert!((6.0..=8.0).contains(&value));
            }
        }
    }

    #[test]
    fn test_create_balanced_dataset_smote_rejects_tiny_minority_class() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 9.0, 9.0]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        let result = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 1 },
            Some(42),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect()).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();
        assert_eq!(count_class(&over_targets, 0.0), 4);
        assert_eq!(count_class(&over_targets, 1.0), 4);
        assert_eq!(count_class(&over_targets, 2.0), 4);

        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();
        assert_eq!(count_class(&under_targets, 0.0), 2);
        assert_eq!(count_class(&under_targets, 1.0), 2);
        assert_eq!(count_class(&under_targets, 2.0), 2);
    }
}