1use crate::error::{DatasetsError, Result};
8use ndarray::{Array1, Array2};
9use rand::prelude::*;
10use rand::rng;
11use rand::rngs::StdRng;
12use std::collections::HashMap;
13
/// Strategy used to even out class frequencies in a labelled dataset.
///
/// The type is a small `Copy` value, so it additionally derives `PartialEq`,
/// `Eq` and `Hash` to let callers compare strategies or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BalancingStrategy {
    /// Duplicate randomly chosen rows of the smaller classes until every class
    /// is as large as the largest one.
    RandomOversample,
    /// Randomly discard rows of the larger classes until every class is as
    /// small as the smallest one.
    RandomUndersample,
    /// Grow the smaller classes with synthetic rows interpolated between a
    /// sample and one of its nearest same-class neighbours.
    SMOTE {
        /// Number of nearest same-class neighbours considered when picking the
        /// interpolation partner.
        k_neighbors: usize,
    },
}
27
28pub fn random_oversample(
55 data: &Array2<f64>,
56 targets: &Array1<f64>,
57 random_seed: Option<u64>,
58) -> Result<(Array2<f64>, Array1<f64>)> {
59 if data.nrows() != targets.len() {
60 return Err(DatasetsError::InvalidFormat(
61 "Data rows and targets length must match".to_string(),
62 ));
63 }
64
65 if data.is_empty() || targets.is_empty() {
66 return Err(DatasetsError::InvalidFormat(
67 "Data and targets cannot be empty".to_string(),
68 ));
69 }
70
71 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
73 for (i, &target) in targets.iter().enumerate() {
74 let class = target.round() as i64;
75 class_indices.entry(class).or_default().push(i);
76 }
77
78 let max_class_size = class_indices.values().map(|v| v.len()).max().unwrap();
80
81 let mut rng = match random_seed {
82 Some(seed) => StdRng::seed_from_u64(seed),
83 None => {
84 let mut r = rng();
85 StdRng::seed_from_u64(r.next_u64())
86 }
87 };
88
89 let mut resampled_indices = Vec::new();
91
92 for (_, indices) in class_indices {
93 let class_size = indices.len();
94
95 resampled_indices.extend(&indices);
97
98 if class_size < max_class_size {
100 let samples_needed = max_class_size - class_size;
101 for _ in 0..samples_needed {
102 let random_idx = rng.random_range(0..class_size);
103 resampled_indices.push(indices[random_idx]);
104 }
105 }
106 }
107
108 let resampled_data = data.select(ndarray::Axis(0), &resampled_indices);
110 let resampled_targets = targets.select(ndarray::Axis(0), &resampled_indices);
111
112 Ok((resampled_data, resampled_targets))
113}
114
115pub fn random_undersample(
142 data: &Array2<f64>,
143 targets: &Array1<f64>,
144 random_seed: Option<u64>,
145) -> Result<(Array2<f64>, Array1<f64>)> {
146 if data.nrows() != targets.len() {
147 return Err(DatasetsError::InvalidFormat(
148 "Data rows and targets length must match".to_string(),
149 ));
150 }
151
152 if data.is_empty() || targets.is_empty() {
153 return Err(DatasetsError::InvalidFormat(
154 "Data and targets cannot be empty".to_string(),
155 ));
156 }
157
158 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
160 for (i, &target) in targets.iter().enumerate() {
161 let class = target.round() as i64;
162 class_indices.entry(class).or_default().push(i);
163 }
164
165 let min_class_size = class_indices.values().map(|v| v.len()).min().unwrap();
167
168 let mut rng = match random_seed {
169 Some(seed) => StdRng::seed_from_u64(seed),
170 None => {
171 let mut r = rng();
172 StdRng::seed_from_u64(r.next_u64())
173 }
174 };
175
176 let mut undersampled_indices = Vec::new();
178
179 for (_, mut indices) in class_indices {
180 if indices.len() > min_class_size {
181 indices.shuffle(&mut rng);
183 undersampled_indices.extend(&indices[0..min_class_size]);
184 } else {
185 undersampled_indices.extend(&indices);
187 }
188 }
189
190 let undersampled_data = data.select(ndarray::Axis(0), &undersampled_indices);
192 let undersampled_targets = targets.select(ndarray::Axis(0), &undersampled_indices);
193
194 Ok((undersampled_data, undersampled_targets))
195}
196
197pub fn generate_synthetic_samples(
228 data: &Array2<f64>,
229 targets: &Array1<f64>,
230 target_class: f64,
231 n_synthetic: usize,
232 k_neighbors: usize,
233 random_seed: Option<u64>,
234) -> Result<(Array2<f64>, Array1<f64>)> {
235 if data.nrows() != targets.len() {
236 return Err(DatasetsError::InvalidFormat(
237 "Data rows and targets length must match".to_string(),
238 ));
239 }
240
241 if n_synthetic == 0 {
242 return Err(DatasetsError::InvalidFormat(
243 "Number of synthetic samples must be > 0".to_string(),
244 ));
245 }
246
247 if k_neighbors == 0 {
248 return Err(DatasetsError::InvalidFormat(
249 "Number of neighbors must be > 0".to_string(),
250 ));
251 }
252
253 let class_indices: Vec<usize> = targets
255 .iter()
256 .enumerate()
257 .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
258 .map(|(i, _)| i)
259 .collect();
260
261 if class_indices.len() < 2 {
262 return Err(DatasetsError::InvalidFormat(
263 "Need at least 2 samples of the target class for synthetic generation".to_string(),
264 ));
265 }
266
267 if k_neighbors >= class_indices.len() {
268 return Err(DatasetsError::InvalidFormat(
269 "k_neighbors must be less than the number of samples in the target class".to_string(),
270 ));
271 }
272
273 let mut rng = match random_seed {
274 Some(seed) => StdRng::seed_from_u64(seed),
275 None => {
276 let mut r = rng();
277 StdRng::seed_from_u64(r.next_u64())
278 }
279 };
280
281 let n_features = data.ncols();
282 let mut synthetic_data = Array2::zeros((n_synthetic, n_features));
283 let synthetic_targets = Array1::from_elem(n_synthetic, target_class);
284
285 for i in 0..n_synthetic {
286 let base_idx = class_indices[rng.random_range(0..class_indices.len())];
288 let base_sample = data.row(base_idx);
289
290 let mut distances: Vec<(usize, f64)> = class_indices
292 .iter()
293 .filter(|&&idx| idx != base_idx)
294 .map(|&idx| {
295 let neighbor = data.row(idx);
296 let distance: f64 = base_sample
297 .iter()
298 .zip(neighbor.iter())
299 .map(|(&a, &b)| (a - b).powi(2))
300 .sum::<f64>()
301 .sqrt();
302 (idx, distance)
303 })
304 .collect();
305
306 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
307 let k_nearest = &distances[0..k_neighbors.min(distances.len())];
308
309 let neighbor_idx = k_nearest[rng.random_range(0..k_nearest.len())].0;
311 let neighbor_sample = data.row(neighbor_idx);
312
313 let alpha = rng.random_range(0.0..1.0);
315 for (j, synthetic_feature) in synthetic_data.row_mut(i).iter_mut().enumerate() {
316 *synthetic_feature = base_sample[j] + alpha * (neighbor_sample[j] - base_sample[j]);
317 }
318 }
319
320 Ok((synthetic_data, synthetic_targets))
321}
322
323pub fn create_balanced_dataset(
350 data: &Array2<f64>,
351 targets: &Array1<f64>,
352 strategy: BalancingStrategy,
353 random_seed: Option<u64>,
354) -> Result<(Array2<f64>, Array1<f64>)> {
355 match strategy {
356 BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
357 BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
358 BalancingStrategy::SMOTE { k_neighbors } => {
359 let mut class_counts: HashMap<i64, usize> = HashMap::new();
361 for &target in targets.iter() {
362 let class = target.round() as i64;
363 *class_counts.entry(class).or_default() += 1;
364 }
365
366 let max_count = *class_counts.values().max().unwrap();
367 let mut combined_data = data.clone();
368 let mut combined_targets = targets.clone();
369
370 for (&class, &count) in &class_counts {
371 if count < max_count {
372 let samples_needed = max_count - count;
373 let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
374 data,
375 targets,
376 class as f64,
377 samples_needed,
378 k_neighbors,
379 random_seed,
380 )?;
381
382 combined_data = ndarray::concatenate(
384 ndarray::Axis(0),
385 &[combined_data.view(), synthetic_data.view()],
386 )
387 .map_err(|_| {
388 DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
389 })?;
390
391 combined_targets = ndarray::concatenate(
392 ndarray::Axis(0),
393 &[combined_targets.view(), synthetic_targets.view()],
394 )
395 .map_err(|_| {
396 DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
397 })?;
398 }
399 }
400
401 Ok((combined_data, combined_targets))
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Six two-feature rows: two of class 0 followed by four of class 1.
    fn imbalanced_binary() -> (Array2<f64>, Array1<f64>) {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
        (data, targets)
    }

    /// Four two-feature rows: three clustered rows of class 0 and one of class 1.
    fn small_cluster() -> (Array2<f64>, Array1<f64>) {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
        (data, targets)
    }

    /// Number of targets exactly equal to `class`.
    fn count_class(targets: &Array1<f64>, class: f64) -> usize {
        targets.iter().filter(|&&t| t == class).count()
    }

    #[test]
    fn test_random_oversample() {
        let (data, targets) = imbalanced_binary();

        let (balanced_data, balanced_targets) =
            random_oversample(&data, &targets, Some(42)).unwrap();

        // The minority class is raised to the majority size.
        assert_eq!(count_class(&balanced_targets, 0.0), 4);
        assert_eq!(count_class(&balanced_targets, 1.0), 4);

        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);

        // The feature count is untouched.
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_oversample_invalid_params() {
        // Mismatched lengths.
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);
        assert!(random_oversample(&data, &targets, None).is_err());

        // Empty input.
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_random_undersample() {
        let (data, targets) = imbalanced_binary();

        let (balanced_data, balanced_targets) =
            random_undersample(&data, &targets, Some(42)).unwrap();

        // The majority class is cut down to the minority size.
        assert_eq!(count_class(&balanced_targets, 0.0), 2);
        assert_eq!(count_class(&balanced_targets, 1.0), 2);

        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);

        // The feature count is untouched.
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_undersample_invalid_params() {
        // Mismatched lengths.
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from(vec![0.0, 1.0]);
        assert!(random_undersample(&data, &targets, None).is_err());

        // Empty input.
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_generate_synthetic_samples() {
        let (data, targets) = small_cluster();

        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42)).unwrap();

        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);
        assert_eq!(synthetic_data.ncols(), 2);

        // Every synthetic row carries the requested class.
        for &target in synthetic_targets.iter() {
            assert_eq!(target, 0.0);
        }

        // Interpolating between class-0 rows cannot leave their bounding box,
        // which spans 1.0..=2.0 in both features.
        for &value in synthetic_data.iter() {
            assert!((1.0..=2.0).contains(&value));
        }
    }

    #[test]
    fn test_generate_synthetic_samples_is_reproducible_with_seed() {
        let (data, targets) = small_cluster();

        let first = generate_synthetic_samples(&data, &targets, 0.0, 5, 2, Some(7)).unwrap();
        let second = generate_synthetic_samples(&data, &targets, 0.0, 5, 2, Some(7)).unwrap();

        assert_eq!(first.0, second.0);
        assert_eq!(first.1, second.1);
    }

    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let (data, targets) = small_cluster();

        // Mismatched lengths.
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());

        // Zero synthetic samples requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());

        // Zero neighbours requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());

        // Target class with a single row.
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());

        // k_neighbors not smaller than the class size.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let (data, targets) = imbalanced_binary();

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let (data, targets) = imbalanced_binary();

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote() {
        // Three rows of class 0 and five of class 1, so two synthetic class-0
        // rows have to be generated.
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0,
            ],
        )
        .unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .unwrap();

        assert_eq!(count_class(&balanced_targets, 0.0), 5);
        assert_eq!(count_class(&balanced_targets, 1.0), 5);
        assert_eq!(balanced_data.nrows(), 10);
        assert_eq!(balanced_targets.len(), 10);
        assert_eq!(balanced_data.ncols(), 2);

        // The original rows are kept, in order, ahead of the synthetic ones.
        for i in 0..data.nrows() {
            assert_eq!(balanced_targets[i], targets[i]);
            for j in 0..data.ncols() {
                assert_eq!(balanced_data[[i, j]], data[[i, j]]);
            }
        }

        // The appended rows are class-0 interpolations inside the class-0 range.
        for i in data.nrows()..balanced_data.nrows() {
            assert_eq!(balanced_targets[i], 0.0);
            for j in 0..balanced_data.ncols() {
                assert!((1.0..=2.0).contains(&balanced_data[[i, j]]));
            }
        }
    }

    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        // Class sizes: 2, 4 and 3.
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect()).unwrap();
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .unwrap();

        // Every class is raised to the largest class size.
        assert_eq!(count_class(&over_targets, 0.0), 4);
        assert_eq!(count_class(&over_targets, 1.0), 4);
        assert_eq!(count_class(&over_targets, 2.0), 4);

        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .unwrap();

        // Every class is cut down to the smallest class size.
        assert_eq!(count_class(&under_targets, 0.0), 2);
        assert_eq!(count_class(&under_targets, 1.0), 2);
        assert_eq!(count_class(&under_targets, 2.0), 2);
    }
}