1use crate::bagging::BaggingClassifier;
8use scirs2_core::ndarray::{Array1, Array2};
11use sklears_core::{
12 error::{Result as SklResult, SklearsError},
13 prelude::Predict,
14 traits::{Estimator, Fit, Trained, Untrained},
15};
16use std::collections::HashMap;
17
18fn gen_f64(rng: &mut impl scirs2_core::random::RngCore) -> f64 {
20 let mut bytes = [0u8; 8];
21 rng.fill_bytes(&mut bytes);
22 f64::from_le_bytes(bytes) / f64::from_le_bytes([255u8; 8])
23}
24
25fn gen_range_usize(
27 rng: &mut impl scirs2_core::random::RngCore,
28 range: std::ops::Range<usize>,
29) -> usize {
30 let mut bytes = [0u8; 8];
31 rng.fill_bytes(&mut bytes);
32 let val = u64::from_le_bytes(bytes);
33 range.start + (val as usize % (range.end - range.start))
34}
35
/// Configuration for [`ImbalancedEnsembleClassifier`].
#[derive(Debug, Clone)]
pub struct ImbalancedEnsembleConfig {
    /// Number of base classifiers trained in the ensemble.
    pub n_estimators: usize,
    /// Resampling scheme applied to the training data.
    pub sampling_strategy: SamplingStrategy,
    /// Optional cost-sensitive learning settings.
    pub cost_sensitive_config: Option<CostSensitiveConfig>,
    /// How base-classifier predictions are combined at predict time.
    pub combination_strategy: CombinationStrategy,
    /// Whether bootstrap samples are drawn with equal per-class counts.
    pub balanced_bootstrap: bool,
    /// Optional decision-threshold optimization strategy.
    pub threshold_moving: Option<ThresholdMovingStrategy>,
    /// Ratio used by `RandomUnderSampling` (target is
    /// `min_class_count * (1 + ratio)`).
    pub under_sampling_ratio: f64,
    /// Ratio used by `RandomOverSampling` (target is
    /// `max_class_count * ratio`).
    pub over_sampling_ratio: f64,
    /// SMOTE parameters used by the SMOTE-family strategies.
    pub smote_config: Option<SMOTEConfig>,
    /// RNG seed; `None` falls back to a fixed default seed (42).
    pub random_state: Option<u64>,
}
60
61impl Default for ImbalancedEnsembleConfig {
62 fn default() -> Self {
63 Self {
64 n_estimators: 10,
65 sampling_strategy: SamplingStrategy::SMOTE,
66 cost_sensitive_config: None,
67 combination_strategy: CombinationStrategy::WeightedVoting,
68 balanced_bootstrap: true,
69 threshold_moving: Some(ThresholdMovingStrategy::Youden),
70 under_sampling_ratio: 0.5,
71 over_sampling_ratio: 1.0,
72 smote_config: Some(SMOTEConfig::default()),
73 random_state: None,
74 }
75 }
76}
77
/// Resampling strategies for handling class imbalance.
///
/// Only `None`, random under/over-sampling, `SMOTE`, `ADASYN`, `TomekLinks`
/// and `SMOTEENN` have dedicated code paths in `apply_sampling_strategy`;
/// every other variant currently falls back to plain SMOTE.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Use the training data unchanged.
    None,
    /// Randomly discard majority-class samples.
    RandomUnderSampling,
    /// Randomly duplicate minority-class samples.
    RandomOverSampling,
    /// Synthetic Minority Over-sampling TEchnique (interpolated samples).
    SMOTE,
    /// Adaptive synthetic sampling (currently delegates to SMOTE).
    ADASYN,
    /// Borderline-SMOTE (not yet implemented; falls back to SMOTE).
    BorderlineSMOTE,
    /// SVM-guided SMOTE (not yet implemented; falls back to SMOTE).
    SVMSMOTE,
    /// Edited Nearest Neighbours cleaning (falls back to SMOTE here).
    EditedNearestNeighbors,
    /// Tomek-link style cleaning (see `tomek_links_sampling`).
    TomekLinks,
    /// Neighbourhood cleaning rule (falls back to SMOTE here).
    NeighborhoodCleaning,
    /// SMOTE followed by Edited Nearest Neighbours cleanup.
    SMOTEENN,
    /// SMOTE followed by Tomek-link removal (falls back to SMOTE here).
    SMOTETomek,
}
106
/// Settings for cost-sensitive learning.
#[derive(Debug, Clone)]
pub struct CostSensitiveConfig {
    /// Misclassification cost matrix.
    /// NOTE(review): the row/column indexing convention is not enforced
    /// anywhere; only per-row sums are consumed in
    /// `apply_cost_sensitive_weights` — confirm intended orientation.
    pub cost_matrix: Array2<f64>,
    /// Normalize weights so their mean is 1 after cost adjustment.
    pub class_balanced_weights: bool,
    /// Explicit per-class weights; these override computed balanced weights.
    pub class_weights: Option<HashMap<usize, f64>>,
    /// Which cost-sensitive algorithm was selected.
    pub algorithm: CostSensitiveAlgorithm,
}
119
/// Cost-sensitive learning algorithm selector.
///
/// NOTE(review): variants name standard approaches from the literature; in
/// this file the selection is only stored, not interpreted.
#[derive(Debug, Clone, PartialEq)]
pub enum CostSensitiveAlgorithm {
    CostSensitiveDecisionTree,
    MetaCost,
    CostSensitiveBoosting,
    ThresholdMoving,
}
128
/// How base-classifier predictions are combined into a final prediction.
///
/// Only majority and weighted voting are implemented in `predict`; the
/// remaining variants fall back to majority voting.
#[derive(Debug, Clone, PartialEq)]
pub enum CombinationStrategy {
    /// One vote per estimator; most-voted class wins.
    MajorityVoting,
    /// Votes weighted per estimator (see `weighted_voting`).
    WeightedVoting,
    /// Stacking tailored for imbalanced data (not implemented).
    ImbalancedStacking,
    /// Per-sample dynamic classifier selection (not implemented).
    DynamicSelection,
    /// Bayesian model combination (not implemented).
    BayesianCombination,
}
143
/// Strategies for moving per-class decision thresholds.
///
/// NOTE(review): the optimizer methods are stubs — every strategy currently
/// resolves to a flat 0.5 threshold per class.
#[derive(Debug, Clone, PartialEq)]
pub enum ThresholdMovingStrategy {
    /// Youden's J statistic.
    Youden,
    /// Threshold maximizing F1.
    F1Optimal,
    /// Precision/recall trade-off optimum.
    PrecisionRecallOptimal,
    /// Threshold minimizing expected cost.
    CostSensitive,
    /// Threshold maximizing balanced accuracy.
    BalancedAccuracy,
}
158
/// Parameters for the SMOTE sampler.
#[derive(Debug, Clone)]
pub struct SMOTEConfig {
    /// Number of nearest neighbours used when interpolating new samples.
    pub k_neighbors: usize,
    /// Target minority size as a fraction of the majority class size
    /// (1.0 fully balances the classes).
    pub sampling_strategy: f64,
    /// RNG seed; `None` falls back to the fixed default seed (42).
    pub random_state: Option<u64>,
    /// Restrict over-sampling to selected samples.
    /// NOTE(review): not consulted anywhere in this file yet.
    pub selective: bool,
    /// Borderline-SMOTE variant selector.
    /// NOTE(review): not consulted anywhere in this file yet.
    pub borderline_mode: BorderlineMode,
}
173
174impl Default for SMOTEConfig {
175 fn default() -> Self {
176 Self {
177 k_neighbors: 5,
178 sampling_strategy: 1.0,
179 random_state: None,
180 selective: false,
181 borderline_mode: BorderlineMode::Borderline1,
182 }
183 }
184}
185
/// Borderline-SMOTE variant selector.
///
/// NOTE(review): stored but not consumed anywhere in this file yet.
#[derive(Debug, Clone, PartialEq)]
pub enum BorderlineMode {
    Borderline1,
    Borderline2,
}
194
/// Ensemble classifier for imbalanced data, using the typestate pattern:
/// `State` is `Untrained` until `fit` produces a `Trained` instance.
pub struct ImbalancedEnsembleClassifier<State = Untrained> {
    // Configuration the ensemble was built with.
    config: ImbalancedEnsembleConfig,
    // Zero-sized typestate marker (Untrained / Trained).
    state: std::marker::PhantomData<State>,
    // Fitted base classifiers; `None` until training completes.
    base_classifiers: Option<Vec<BaggingClassifier<Trained>>>,
    // Inverse-frequency (balanced) weight per class label.
    class_weights: Option<HashMap<usize, f64>>,
    // Per-class decision thresholds after threshold moving.
    optimal_thresholds: Option<HashMap<usize, f64>>,
    // Sample count observed per class label.
    class_distributions: Option<HashMap<usize, usize>>,
    // Bookkeeping for each resampling run.
    sampling_results: Option<Vec<SamplingResult>>,
}
206
/// Record of one resampling run.
#[derive(Debug, Clone)]
pub struct SamplingResult {
    /// Class counts before resampling.
    pub original_distribution: HashMap<usize, usize>,
    /// Class counts after resampling.
    pub resampled_distribution: HashMap<usize, usize>,
    /// Quality measures for the run.
    pub quality_metrics: SamplingQualityMetrics,
}
217
/// Quality measures describing a resampling run.
///
/// NOTE(review): nothing in this file populates these yet; field meanings
/// below are the apparent intent from their names — confirm when wired up.
#[derive(Debug, Clone)]
pub struct SamplingQualityMetrics {
    pub balance_ratio: f64,
    pub information_preservation: f64,
    pub diversity_increase: f64,
    pub computational_overhead: f64,
}
230
/// Stateful SMOTE sampler: configuration plus a seeded RNG.
pub struct SMOTESampler {
    // Parameters controlling neighbour count and target balance.
    config: SMOTEConfig,
    // Seeded so repeated runs with the same config are reproducible.
    rng: scirs2_core::random::CoreRandom<scirs2_core::random::rngs::StdRng>,
}
236
impl ImbalancedEnsembleConfig {
    /// Start building a config fluently; finish with `build()`.
    pub fn builder() -> ImbalancedEnsembleConfigBuilder {
        ImbalancedEnsembleConfigBuilder::default()
    }
}
242
/// Fluent builder for [`ImbalancedEnsembleConfig`].
#[derive(Default)]
pub struct ImbalancedEnsembleConfigBuilder {
    // Accumulates values; returned unchanged by `build()`.
    config: ImbalancedEnsembleConfig,
}
247
248impl ImbalancedEnsembleConfigBuilder {
249 pub fn n_estimators(mut self, n_estimators: usize) -> Self {
250 self.config.n_estimators = n_estimators;
251 self
252 }
253
254 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
255 self.config.sampling_strategy = strategy;
256 self
257 }
258
259 pub fn cost_sensitive_config(mut self, config: CostSensitiveConfig) -> Self {
260 self.config.cost_sensitive_config = Some(config);
261 self
262 }
263
264 pub fn combination_strategy(mut self, strategy: CombinationStrategy) -> Self {
265 self.config.combination_strategy = strategy;
266 self
267 }
268
269 pub fn balanced_bootstrap(mut self, balanced: bool) -> Self {
270 self.config.balanced_bootstrap = balanced;
271 self
272 }
273
274 pub fn threshold_moving(mut self, strategy: ThresholdMovingStrategy) -> Self {
275 self.config.threshold_moving = Some(strategy);
276 self
277 }
278
279 pub fn smote_config(mut self, config: SMOTEConfig) -> Self {
280 self.config.smote_config = Some(config);
281 self
282 }
283
284 pub fn random_state(mut self, seed: u64) -> Self {
285 self.config.random_state = Some(seed);
286 self
287 }
288
289 pub fn build(self) -> ImbalancedEnsembleConfig {
290 self.config
291 }
292}
293
294impl ImbalancedEnsembleClassifier<Untrained> {
    /// Create an untrained classifier from `config`.
    ///
    /// Bookkeeping maps are pre-initialized to empty so the analysis helpers
    /// can populate them in place.
    pub fn new(config: ImbalancedEnsembleConfig) -> Self {
        Self {
            config,
            state: std::marker::PhantomData,
            base_classifiers: None,
            class_weights: Some(HashMap::new()),
            optimal_thresholds: Some(HashMap::new()),
            class_distributions: Some(HashMap::new()),
            sampling_results: Some(Vec::new()),
        }
    }
306
    /// Start building a classifier fluently.
    pub fn builder() -> ImbalancedEnsembleClassifierBuilder {
        ImbalancedEnsembleClassifierBuilder::new()
    }
310
311 fn analyze_class_distribution(&mut self, y: &[usize]) -> SklResult<()> {
313 self.class_distributions.as_mut().unwrap().clear();
314
315 for &class in y {
316 *self
317 .class_distributions
318 .as_mut()
319 .unwrap()
320 .entry(class)
321 .or_insert(0) += 1;
322 }
323
324 let total_samples = y.len();
326 let n_classes = self.class_distributions.as_ref().unwrap().len();
327
328 for (&class, &count) in self.class_distributions.as_ref().unwrap() {
329 let weight = total_samples as f64 / (n_classes as f64 * count as f64);
330 self.class_weights.as_mut().unwrap().insert(class, weight);
331 }
332
333 Ok(())
334 }
335
    /// Dispatch to the configured resampling routine.
    ///
    /// Strategies without a dedicated implementation (Borderline/SVM SMOTE,
    /// ENN, neighbourhood cleaning, SMOTETomek, ...) fall back to plain SMOTE.
    fn apply_sampling_strategy(
        &mut self,
        X: &Array2<f64>,
        y: &[usize],
    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
        match self.config.sampling_strategy {
            SamplingStrategy::None => Ok((X.clone(), y.to_vec())),
            SamplingStrategy::RandomUnderSampling => self.random_under_sampling(X, y),
            SamplingStrategy::RandomOverSampling => self.random_over_sampling(X, y),
            SamplingStrategy::SMOTE => self.smote_sampling(X, y),
            SamplingStrategy::ADASYN => self.adasyn_sampling(X, y),
            SamplingStrategy::TomekLinks => self.tomek_links_sampling(X, y),
            SamplingStrategy::SMOTEENN => self.smoteenn_sampling(X, y),
            _ => {
                // Fallback for not-yet-implemented strategies.
                self.smote_sampling(X, y)
            }
        }
    }
356
357 fn random_under_sampling(
359 &self,
360 X: &Array2<f64>,
361 y: &[usize],
362 ) -> SklResult<(Array2<f64>, Vec<usize>)> {
363 let mut rng = if let Some(seed) = self.config.random_state {
364 scirs2_core::random::seeded_rng(seed)
365 } else {
366 scirs2_core::random::seeded_rng(42)
367 };
368
369 let min_class_size = self
371 .class_distributions
372 .as_ref()
373 .unwrap()
374 .values()
375 .min()
376 .copied()
377 .unwrap_or(0);
378 let target_size =
379 (min_class_size as f64 * (1.0 + self.config.under_sampling_ratio)) as usize;
380
381 let mut resampled_indices = Vec::new();
382
383 for (&class, &count) in self.class_distributions.as_ref().unwrap() {
384 let class_indices: Vec<usize> = y
385 .iter()
386 .enumerate()
387 .filter(|(_, &c)| c == class)
388 .map(|(i, _)| i)
389 .collect();
390
391 let sample_size = if count > target_size {
392 target_size
393 } else {
394 count
395 };
396
397 let mut selected_indices = class_indices;
399 selected_indices.truncate(sample_size);
400
401 for i in (1..selected_indices.len()).rev() {
403 let j = gen_range_usize(&mut rng, 0..(i + 1));
404 selected_indices.swap(i, j);
405 }
406
407 resampled_indices.extend(selected_indices);
408 }
409
410 let n_features = X.shape()[1];
412 let mut resampled_X = Vec::with_capacity(resampled_indices.len() * n_features);
413 let mut resampled_y = Vec::with_capacity(resampled_indices.len());
414
415 for &idx in &resampled_indices {
416 for j in 0..n_features {
417 resampled_X.push(X[[idx, j]]);
418 }
419 resampled_y.push(y[idx]);
420 }
421
422 let X_resampled =
423 Array2::from_shape_vec((resampled_indices.len(), n_features), resampled_X)?;
424
425 Ok((X_resampled, resampled_y))
426 }
427
428 fn random_over_sampling(
430 &self,
431 X: &Array2<f64>,
432 y: &[usize],
433 ) -> SklResult<(Array2<f64>, Vec<usize>)> {
434 let mut rng = if let Some(seed) = self.config.random_state {
435 scirs2_core::random::seeded_rng(seed)
436 } else {
437 scirs2_core::random::seeded_rng(42)
438 };
439
440 let max_class_size = self
442 .class_distributions
443 .as_ref()
444 .unwrap()
445 .values()
446 .max()
447 .copied()
448 .unwrap_or(0);
449 let target_size = (max_class_size as f64 * self.config.over_sampling_ratio) as usize;
450
451 let mut resampled_X = Vec::new();
452 let mut resampled_y = Vec::new();
453
454 for (&class, &count) in self.class_distributions.as_ref().unwrap() {
455 let class_indices: Vec<usize> = y
456 .iter()
457 .enumerate()
458 .filter(|(_, &c)| c == class)
459 .map(|(i, _)| i)
460 .collect();
461
462 for &idx in &class_indices {
464 for j in 0..X.shape()[1] {
465 resampled_X.push(X[[idx, j]]);
466 }
467 resampled_y.push(class);
468 }
469
470 if count < target_size {
472 let additional_samples = target_size - count;
473 for _ in 0..additional_samples {
474 let random_idx =
475 class_indices[gen_range_usize(&mut rng, 0..class_indices.len())];
476 for j in 0..X.shape()[1] {
477 resampled_X.push(X[[random_idx, j]]);
478 }
479 resampled_y.push(class);
480 }
481 }
482 }
483
484 let n_features = X.shape()[1];
485 let n_samples = resampled_y.len();
486 let X_resampled = Array2::from_shape_vec((n_samples, n_features), resampled_X)?;
487
488 Ok((X_resampled, resampled_y))
489 }
490
    /// Run SMOTE with the configured parameters (or defaults when the config
    /// carries no `smote_config`).
    fn smote_sampling(&self, X: &Array2<f64>, y: &[usize]) -> SklResult<(Array2<f64>, Vec<usize>)> {
        let default_config = SMOTEConfig::default();
        let smote_config = self.config.smote_config.as_ref().unwrap_or(&default_config);

        let mut sampler = SMOTESampler::new(smote_config.clone());
        sampler.fit_resample(X, y)
    }
499
    /// ADASYN adaptive sampling.
    ///
    /// NOTE(review): currently a plain SMOTE delegate; the density-based
    /// adaptive weighting that distinguishes ADASYN is not implemented yet.
    fn adasyn_sampling(
        &self,
        X: &Array2<f64>,
        y: &[usize],
    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
        self.smote_sampling(X, y)
    }
510
    /// Clean the data by keeping only samples whose single nearest neighbour
    /// has the same label.
    ///
    /// NOTE(review): classic Tomek-link removal deletes only mutual
    /// nearest-neighbour pairs of *different* classes (usually majority
    /// side); this 1-NN rule is stricter and may discard more samples.
    /// O(n^2) distance computations.
    #[allow(non_snake_case)]
    fn tomek_links_sampling(
        &self,
        X: &Array2<f64>,
        y: &[usize],
    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
        let mut keep_indices = Vec::new();

        for i in 0..X.shape()[0] {
            let mut nearest_distance = f64::INFINITY;
            let mut nearest_class = y[i];

            // Linear scan for the nearest neighbour of sample i.
            for j in 0..X.shape()[0] {
                if i != j {
                    let distance = self.euclidean_distance(X, i, j);
                    if distance < nearest_distance {
                        nearest_distance = distance;
                        nearest_class = y[j];
                    }
                }
            }

            // Keep the sample when its nearest neighbour agrees with it.
            if nearest_class == y[i] {
                keep_indices.push(i);
            }
        }

        // Materialize the surviving rows.
        let n_features = X.shape()[1];
        let mut cleaned_X = Vec::with_capacity(keep_indices.len() * n_features);
        let mut cleaned_y = Vec::with_capacity(keep_indices.len());

        for &idx in &keep_indices {
            for j in 0..n_features {
                cleaned_X.push(X[[idx, j]]);
            }
            cleaned_y.push(y[idx]);
        }

        let X_cleaned = Array2::from_shape_vec((keep_indices.len(), n_features), cleaned_X)?;

        Ok((X_cleaned, cleaned_y))
    }
560
    /// SMOTE followed by Edited Nearest Neighbours cleanup (SMOTE-ENN).
    fn smoteenn_sampling(
        &self,
        X: &Array2<f64>,
        y: &[usize],
    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
        // Over-sample the minority classes first...
        let (X_smote, y_smote) = self.smote_sampling(X, y)?;

        // ...then drop samples that disagree with their neighbourhood.
        self.edited_nearest_neighbors_cleaning(&X_smote, &y_smote)
    }
573
574 #[allow(non_snake_case)]
576 fn edited_nearest_neighbors_cleaning(
577 &self,
578 X: &Array2<f64>,
579 y: &[usize],
580 ) -> SklResult<(Array2<f64>, Vec<usize>)> {
581 let k = 3; let mut keep_indices = Vec::new();
583
584 for i in 0..X.shape()[0] {
585 let neighbors = self.find_k_nearest_neighbors(X, i, k);
586 let neighbor_classes: Vec<usize> = neighbors.iter().map(|&idx| y[idx]).collect();
587
588 let same_class_count = neighbor_classes.iter().filter(|&&c| c == y[i]).count();
590
591 if same_class_count > neighbors.len() / 2 {
592 keep_indices.push(i);
593 }
594 }
595
596 let n_features = X.shape()[1];
598 let mut cleaned_X = Vec::with_capacity(keep_indices.len() * n_features);
599 let mut cleaned_y = Vec::with_capacity(keep_indices.len());
600
601 for &idx in &keep_indices {
602 for j in 0..n_features {
603 cleaned_X.push(X[[idx, j]]);
604 }
605 cleaned_y.push(y[idx]);
606 }
607
608 let X_cleaned = Array2::from_shape_vec((keep_indices.len(), n_features), cleaned_X)?;
609
610 Ok((X_cleaned, cleaned_y))
611 }
612
613 fn find_k_nearest_neighbors(&self, X: &Array2<f64>, sample_idx: usize, k: usize) -> Vec<usize> {
615 let mut distances: Vec<(f64, usize)> = Vec::new();
616
617 for i in 0..X.shape()[0] {
618 if i != sample_idx {
619 let distance = self.euclidean_distance(X, sample_idx, i);
620 distances.push((distance, i));
621 }
622 }
623
624 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
625 distances.iter().take(k).map(|(_, idx)| *idx).collect()
626 }
627
628 fn euclidean_distance(&self, X: &Array2<f64>, i: usize, j: usize) -> f64 {
630 let mut sum_squared = 0.0;
631 for k in 0..X.shape()[1] {
632 let diff = X[[i, k]] - X[[j, k]];
633 sum_squared += diff * diff;
634 }
635 sum_squared.sqrt()
636 }
637
    /// Optimize per-class decision thresholds according to the configured
    /// strategy; no-op when threshold moving is disabled.
    fn optimize_thresholds(&mut self, X: &Array2<f64>, y: &[usize]) -> SklResult<()> {
        if let Some(ref strategy) = self.config.threshold_moving {
            match strategy {
                ThresholdMovingStrategy::Youden => {
                    self.optimize_youden_threshold(X, y)?;
                }
                ThresholdMovingStrategy::F1Optimal => {
                    self.optimize_f1_threshold(X, y)?;
                }
                _ => {
                    // Remaining strategies: default every class to 0.5.
                    for &class in self.class_distributions.as_ref().unwrap().keys() {
                        self.optimal_thresholds.as_mut().unwrap().insert(class, 0.5);
                    }
                }
            }
        }
        Ok(())
    }
658
    /// Youden's J threshold optimization.
    ///
    /// NOTE(review): placeholder — assigns a flat 0.5 threshold per class
    /// instead of maximizing sensitivity + specificity - 1.
    fn optimize_youden_threshold(&mut self, _X: &Array2<f64>, _y: &[usize]) -> SklResult<()> {
        for &class in self.class_distributions.as_ref().unwrap().keys() {
            self.optimal_thresholds.as_mut().unwrap().insert(class, 0.5);
        }
        Ok(())
    }
667
    /// F1-optimal threshold optimization.
    ///
    /// NOTE(review): placeholder — assigns a flat 0.5 threshold per class
    /// instead of searching for the F1-maximizing threshold.
    fn optimize_f1_threshold(&mut self, _X: &Array2<f64>, _y: &[usize]) -> SklResult<()> {
        for &class in self.class_distributions.as_ref().unwrap().keys() {
            self.optimal_thresholds.as_mut().unwrap().insert(class, 0.5);
        }
        Ok(())
    }
676
    /// Draw `n_estimators` class-balanced bootstrap samples: each sample
    /// contains `n_rows / n_classes` rows per class, drawn with replacement.
    ///
    /// NOTE(review): requires `analyze_class_distribution` to have run
    /// first — an empty distribution panics (division by zero / empty
    /// `class_indices` in `gen_range_usize`).
    #[allow(non_snake_case)]
    fn create_balanced_bootstrap(
        &self,
        X: &Array2<f64>,
        y: &[usize],
    ) -> SklResult<Vec<(Array2<f64>, Vec<usize>)>> {
        let mut bootstrap_samples = Vec::new();
        let mut rng = if let Some(seed) = self.config.random_state {
            scirs2_core::random::seeded_rng(seed)
        } else {
            scirs2_core::random::seeded_rng(42)
        };

        for _ in 0..self.config.n_estimators {
            let mut sample_indices = Vec::new();

            // Equal share of rows per class (integer division).
            let samples_per_class = X.shape()[0] / self.class_distributions.as_ref().unwrap().len();

            for &class in self.class_distributions.as_ref().unwrap().keys() {
                let class_indices: Vec<usize> = y
                    .iter()
                    .enumerate()
                    .filter(|(_, &c)| c == class)
                    .map(|(i, _)| i)
                    .collect();

                // Draw with replacement from this class's rows.
                for _ in 0..samples_per_class {
                    let random_idx =
                        class_indices[gen_range_usize(&mut rng, 0..class_indices.len())];
                    sample_indices.push(random_idx);
                }
            }

            // Materialize the drawn rows.
            let n_features = X.shape()[1];
            let mut sample_X = Vec::with_capacity(sample_indices.len() * n_features);
            let mut sample_y = Vec::with_capacity(sample_indices.len());

            for &idx in &sample_indices {
                for j in 0..n_features {
                    sample_X.push(X[[idx, j]]);
                }
                sample_y.push(y[idx]);
            }

            let X_sample = Array2::from_shape_vec((sample_indices.len(), n_features), sample_X)?;
            bootstrap_samples.push((X_sample, sample_y));
        }

        Ok(bootstrap_samples)
    }
731
    /// Convenience constructor: cost-sensitive ensemble driven by a
    /// misclassification cost matrix, with balanced weights and weighted
    /// voting.
    pub fn cost_sensitive(cost_matrix: Array2<f64>) -> Self {
        let cost_config = CostSensitiveConfig {
            cost_matrix,
            class_balanced_weights: true,
            class_weights: None,
            algorithm: CostSensitiveAlgorithm::CostSensitiveBoosting,
        };

        let config = ImbalancedEnsembleConfig {
            cost_sensitive_config: Some(cost_config),
            combination_strategy: CombinationStrategy::WeightedVoting,
            ..Default::default()
        };

        Self::new(config)
    }
749
    /// Convenience constructor: cost-sensitive ensemble with explicit
    /// per-class weights. The empty cost matrix disables matrix-based
    /// adjustment in `apply_cost_sensitive_weights`.
    pub fn cost_sensitive_weights(class_weights: HashMap<usize, f64>) -> Self {
        let cost_config = CostSensitiveConfig {
            cost_matrix: Array2::zeros((0, 0)),
            class_balanced_weights: false,
            class_weights: Some(class_weights),
            algorithm: CostSensitiveAlgorithm::CostSensitiveBoosting,
        };

        let config = ImbalancedEnsembleConfig {
            cost_sensitive_config: Some(cost_config),
            combination_strategy: CombinationStrategy::WeightedVoting,
            ..Default::default()
        };

        Self::new(config)
    }
767
    /// Convenience constructor: SMOTE-based ensemble using `k_neighbors`
    /// interpolation neighbours and balanced bootstraps.
    pub fn smote_ensemble(k_neighbors: usize) -> Self {
        let smote_config = SMOTEConfig {
            k_neighbors,
            sampling_strategy: 1.0,
            random_state: None,
            selective: false,
            borderline_mode: BorderlineMode::Borderline1,
        };

        let config = ImbalancedEnsembleConfig {
            sampling_strategy: SamplingStrategy::SMOTE,
            smote_config: Some(smote_config),
            balanced_bootstrap: true,
            ..Default::default()
        };

        Self::new(config)
    }
787}
788
789impl SMOTESampler {
    /// Build a sampler; an unseeded config falls back to the fixed seed 42
    /// so results are reproducible by default.
    pub fn new(config: SMOTEConfig) -> Self {
        let rng = if let Some(seed) = config.random_state {
            scirs2_core::random::seeded_rng(seed)
        } else {
            scirs2_core::random::seeded_rng(42)
        };

        Self { config, rng }
    }
799
    /// Over-sample every minority class up to
    /// `max_class_count * sampling_strategy` rows and return the combined
    /// original + synthetic data.
    ///
    /// # Errors
    /// Propagates `generate_synthetic_samples` failures (class smaller than
    /// `k_neighbors`) and matrix shape errors.
    #[allow(non_snake_case)]
    pub fn fit_resample(
        &mut self,
        X: &Array2<f64>,
        y: &[usize],
    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
        // Count samples per class.
        let mut class_counts = HashMap::new();
        for &class in y {
            *class_counts.entry(class).or_insert(0) += 1;
        }

        let max_count = *class_counts.values().max().unwrap_or(&0);
        let target_count = (max_count as f64 * self.config.sampling_strategy) as usize;

        let mut resampled_X = Vec::new();
        let mut resampled_y = Vec::new();

        // Copy all original rows first.
        for i in 0..X.shape()[0] {
            for j in 0..X.shape()[1] {
                resampled_X.push(X[[i, j]]);
            }
            resampled_y.push(y[i]);
        }

        // Top every under-represented class up with synthetic rows.
        for (&class, &count) in &class_counts {
            if count < target_count {
                let n_synthetic = target_count - count;
                let synthetic_samples =
                    self.generate_synthetic_samples(X, y, class, n_synthetic)?;

                for sample in synthetic_samples {
                    resampled_X.extend(sample);
                    resampled_y.push(class);
                }
            }
        }

        let n_features = X.shape()[1];
        let n_samples = resampled_y.len();
        let X_resampled = Array2::from_shape_vec((n_samples, n_features), resampled_X)?;

        Ok((X_resampled, resampled_y))
    }
848
    /// Create `n_samples` synthetic feature vectors for `target_class` by
    /// interpolating between random class members and their nearest
    /// same-class neighbours.
    ///
    /// # Errors
    /// `InvalidInput` when the class has fewer than `k_neighbors` members.
    fn generate_synthetic_samples(
        &mut self,
        X: &Array2<f64>,
        y: &[usize],
        target_class: usize,
        n_samples: usize,
    ) -> SklResult<Vec<Vec<f64>>> {
        let class_indices: Vec<usize> = y
            .iter()
            .enumerate()
            .filter(|(_, &c)| c == target_class)
            .map(|(i, _)| i)
            .collect();

        if class_indices.len() < self.config.k_neighbors {
            return Err(SklearsError::InvalidInput(format!(
                "Not enough samples of class {} for SMOTE",
                target_class
            )));
        }

        let mut synthetic_samples = Vec::new();

        for _ in 0..n_samples {
            // Pick a random member of the class...
            let sample_idx = class_indices[gen_range_usize(&mut self.rng, 0..class_indices.len())];

            // ...find its nearest same-class neighbours...
            let neighbors = self.find_nearest_neighbors(X, &class_indices, sample_idx)?;

            // ...and interpolate toward one of them chosen at random.
            let neighbor_idx = neighbors[gen_range_usize(&mut self.rng, 0..neighbors.len())];

            let synthetic_sample = self.generate_sample_between(X, sample_idx, neighbor_idx);
            synthetic_samples.push(synthetic_sample);
        }

        Ok(synthetic_samples)
    }
891
892 fn find_nearest_neighbors(
894 &self,
895 X: &Array2<f64>,
896 class_indices: &[usize],
897 sample_idx: usize,
898 ) -> SklResult<Vec<usize>> {
899 let mut distances: Vec<(f64, usize)> = Vec::new();
900
901 for &idx in class_indices {
902 if idx != sample_idx {
903 let distance = self.euclidean_distance(X, sample_idx, idx);
904 distances.push((distance, idx));
905 }
906 }
907
908 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
909 let neighbors = distances
910 .iter()
911 .take(self.config.k_neighbors)
912 .map(|(_, idx)| *idx)
913 .collect();
914
915 Ok(neighbors)
916 }
917
    /// Linear interpolation between rows `idx1` and `idx2`: for each feature,
    /// produce `x1 + t * (x2 - x1)` with a fresh random `t` per feature.
    fn generate_sample_between(&mut self, X: &Array2<f64>, idx1: usize, idx2: usize) -> Vec<f64> {
        let mut synthetic_sample = Vec::new();

        for j in 0..X.shape()[1] {
            let x1 = X[[idx1, j]];
            let x2 = X[[idx2, j]];
            let random_factor = gen_f64(&mut self.rng);

            let synthetic_value = x1 + random_factor * (x2 - x1);
            synthetic_sample.push(synthetic_value);
        }

        synthetic_sample
    }
934
935 fn euclidean_distance(&self, X: &Array2<f64>, i: usize, j: usize) -> f64 {
937 let mut sum_squared = 0.0;
938 for k in 0..X.shape()[1] {
939 let diff = X[[i, k]] - X[[j, k]];
940 sum_squared += diff * diff;
941 }
942 sum_squared.sqrt()
943 }
944}
945
/// Fluent builder for [`ImbalancedEnsembleClassifier`].
pub struct ImbalancedEnsembleClassifierBuilder {
    // Configuration assembled by the setter methods.
    config: ImbalancedEnsembleConfig,
}
949
impl Default for ImbalancedEnsembleClassifierBuilder {
    /// Equivalent to `ImbalancedEnsembleClassifierBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
955
impl ImbalancedEnsembleClassifierBuilder {
    /// Start from the default configuration.
    pub fn new() -> Self {
        Self {
            config: ImbalancedEnsembleConfig::default(),
        }
    }

    /// Replace the entire configuration at once.
    pub fn config(mut self, config: ImbalancedEnsembleConfig) -> Self {
        self.config = config;
        self
    }

    /// Set the number of base estimators.
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    /// Choose the resampling strategy.
    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.config.sampling_strategy = strategy;
        self
    }

    /// Toggle class-balanced bootstrap sampling.
    pub fn balanced_bootstrap(mut self, balanced: bool) -> Self {
        self.config.balanced_bootstrap = balanced;
        self
    }

    /// Build an untrained classifier from the accumulated configuration.
    pub fn build(self) -> ImbalancedEnsembleClassifier<Untrained> {
        ImbalancedEnsembleClassifier::new(self.config)
    }
}
987
impl Estimator for ImbalancedEnsembleClassifier<Untrained> {
    type Config = ImbalancedEnsembleConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrow the configuration this estimator was constructed with.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
997
impl Fit<Array2<f64>, Vec<usize>> for ImbalancedEnsembleClassifier<Untrained> {
    type Fitted = ImbalancedEnsembleClassifier<Trained>;

    /// Train the ensemble on `X` / `y`: compute balanced class weights,
    /// build bootstrap samples and fit one `BaggingClassifier` per sample.
    ///
    /// NOTE(review): several config options are not yet wired in here —
    /// `sampling_strategy` is ignored (`X`/`y` are used unchanged), the
    /// "balanced bootstrap" branch clones the same data `n_estimators` times
    /// instead of drawing per-class bootstraps, `adjusted_weights` is
    /// computed but never passed to the base classifiers, and both arms of
    /// the cost-sensitive `if` below are identical. Confirm intended wiring
    /// before relying on those options.
    #[allow(non_snake_case)]
    fn fit(self, X: &Array2<f64>, y: &Vec<usize>) -> SklResult<Self::Fitted> {
        // Count samples per class.
        let mut class_dist = HashMap::new();
        for &class in y {
            *class_dist.entry(class).or_insert(0) += 1;
        }

        // Balanced (inverse-frequency) class weights.
        let total_samples = y.len() as f64;
        let n_classes = class_dist.len();
        let mut class_weights = HashMap::new();
        for (&class, &count) in &class_dist {
            class_weights.insert(class, total_samples / (n_classes as f64 * count as f64));
        }

        // Resampling placeholder: data is currently passed through as-is.
        let X_resampled = X.clone();
        let y_resampled = y.clone();

        let bootstrap_samples = if self.config.balanced_bootstrap {
            let mut samples = Vec::new();
            for _ in 0..self.config.n_estimators {
                samples.push((X_resampled.clone(), y_resampled.clone()));
            }
            samples
        } else {
            vec![(X_resampled, y_resampled)]
        };

        // Computed for its validation side effects only (see NOTE above).
        let adjusted_weights = if let Some(ref cost_config) = self.config.cost_sensitive_config {
            self.apply_cost_sensitive_weights(&class_weights, cost_config)?
        } else {
            class_weights.clone()
        };

        // Fit one bagging classifier per bootstrap sample; labels are cast
        // to i32 for the BaggingClassifier API.
        let mut trained_base_classifiers = Vec::new();
        for (X_sample, y_sample) in bootstrap_samples {
            let y_sample_array = Array1::from_vec(y_sample.iter().map(|&x| x as i32).collect());

            let classifier = if self.config.cost_sensitive_config.is_some() {
                BaggingClassifier::new()
                    .n_estimators(50)
                    .bootstrap(true)
                    .fit(&X_sample, &y_sample_array)?
            } else {
                BaggingClassifier::new()
                    .n_estimators(50)
                    .bootstrap(true)
                    .fit(&X_sample, &y_sample_array)?
            };

            trained_base_classifiers.push(classifier);
        }

        Ok(ImbalancedEnsembleClassifier {
            config: self.config,
            state: std::marker::PhantomData,
            base_classifiers: Some(trained_base_classifiers),
            class_weights: Some(class_weights),
            optimal_thresholds: Some(HashMap::new()),
            class_distributions: Some(class_dist),
            sampling_results: Some(Vec::new()),
        })
    }
}
1075
impl Predict<Array2<f64>, Vec<usize>> for ImbalancedEnsembleClassifier<Trained> {
    /// Predict class labels for each row of `X` by combining the base
    /// classifiers' votes with the configured combination strategy.
    /// Strategies other than majority/weighted voting fall back to majority.
    fn predict(&self, X: &Array2<f64>) -> SklResult<Vec<usize>> {
        let mut all_predictions = Vec::new();

        // Collect each base classifier's predictions (cast i32 -> usize).
        let base_classifiers = self.base_classifiers.as_ref().expect("Model is trained");
        for classifier in base_classifiers {
            let predictions = classifier.predict(X)?;
            let predictions_vec: Vec<usize> = predictions.iter().map(|&x| x as usize).collect();
            all_predictions.push(predictions_vec);
        }

        match self.config.combination_strategy {
            CombinationStrategy::MajorityVoting => self.majority_voting(&all_predictions),
            CombinationStrategy::WeightedVoting => self.weighted_voting(&all_predictions),
            _ => {
                // Fallback for not-yet-implemented combination strategies.
                self.majority_voting(&all_predictions)
            }
        }
    }
}
1099
1100impl<State> ImbalancedEnsembleClassifier<State> {
1101 fn majority_voting(&self, predictions: &[Vec<usize>]) -> SklResult<Vec<usize>> {
1103 if predictions.is_empty() {
1104 return Err(SklearsError::InvalidInput(
1105 "No predictions to combine".to_string(),
1106 ));
1107 }
1108
1109 let n_samples = predictions[0].len();
1110 let mut final_predictions = Vec::with_capacity(n_samples);
1111
1112 for i in 0..n_samples {
1113 let mut class_votes = HashMap::new();
1114
1115 for pred in predictions {
1116 *class_votes.entry(pred[i]).or_insert(0) += 1;
1117 }
1118
1119 let predicted_class = *class_votes
1120 .iter()
1121 .max_by_key(|(_, &count)| count)
1122 .map(|(class, _)| class)
1123 .unwrap();
1124
1125 final_predictions.push(predicted_class);
1126 }
1127
1128 Ok(final_predictions)
1129 }
1130
1131 fn weighted_voting(&self, predictions: &[Vec<usize>]) -> SklResult<Vec<usize>> {
1133 if predictions.is_empty() {
1134 return Err(SklearsError::InvalidInput(
1135 "No predictions to combine".to_string(),
1136 ));
1137 }
1138
1139 let n_samples = predictions[0].len();
1140 let mut final_predictions = Vec::with_capacity(n_samples);
1141
1142 for i in 0..n_samples {
1143 let mut class_scores = HashMap::new();
1144
1145 for (j, pred) in predictions.iter().enumerate() {
1146 let weight = 1.0 / (j + 1) as f64; *class_scores.entry(pred[i]).or_insert(0.0) += weight;
1148 }
1149
1150 let predicted_class = *class_scores
1151 .iter()
1152 .max_by(|(_, &score1), (_, &score2)| score1.partial_cmp(&score2).unwrap())
1153 .map(|(class, _)| class)
1154 .unwrap();
1155
1156 final_predictions.push(predicted_class);
1157 }
1158
1159 Ok(final_predictions)
1160 }
1161
    /// Observed sample count per class.
    ///
    /// # Panics
    /// Panics if the distribution has not been computed yet.
    pub fn get_class_distribution(&self) -> &HashMap<usize, usize> {
        self.class_distributions
            .as_ref()
            .expect("Class distributions not available")
    }
1168
    /// Balanced weight per class.
    ///
    /// # Panics
    /// Panics if the weights have not been computed yet.
    pub fn get_class_weights(&self) -> &HashMap<usize, f64> {
        self.class_weights
            .as_ref()
            .expect("Class weights not available")
    }
1175
    /// Optimized decision threshold per class.
    ///
    /// # Panics
    /// Panics if the thresholds have not been computed yet.
    pub fn get_optimal_thresholds(&self) -> &HashMap<usize, f64> {
        self.optimal_thresholds
            .as_ref()
            .expect("Optimal thresholds not available")
    }
1182
    /// Adjust balanced class weights using the cost-sensitive settings:
    /// explicit per-class weights override first, then each weight is scaled
    /// by the summed cost of its cost-matrix row, and finally (optionally)
    /// the weights are normalized so their mean is 1.
    fn apply_cost_sensitive_weights(
        &self,
        base_weights: &HashMap<usize, f64>,
        cost_config: &CostSensitiveConfig,
    ) -> SklResult<HashMap<usize, f64>> {
        let mut adjusted_weights = base_weights.clone();

        // 1. Explicit per-class weights take precedence.
        if let Some(ref custom_weights) = cost_config.class_weights {
            for (&class, &weight) in custom_weights {
                adjusted_weights.insert(class, weight);
            }
        }

        // 2. Scale by the per-row cost sum (skipped for the empty matrix
        // used by `cost_sensitive_weights`; classes beyond the matrix are
        // left unscaled).
        if cost_config.cost_matrix.nrows() > 0 {
            for (&class, weight) in &mut adjusted_weights {
                if class < cost_config.cost_matrix.nrows() {
                    let misclassification_cost = cost_config.cost_matrix.row(class).sum();
                    *weight *= misclassification_cost;
                }
            }
        }

        // 3. Normalize so the average weight is 1.
        if cost_config.class_balanced_weights {
            let total_weight: f64 = adjusted_weights.values().sum();
            let n_classes = adjusted_weights.len() as f64;
            let avg_weight = total_weight / n_classes;

            for weight in adjusted_weights.values_mut() {
                *weight /= avg_weight;
            }
        }

        Ok(adjusted_weights)
    }
1222}
1223
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builder round-trip: values set on the builder appear in the config.
    #[test]
    fn test_imbalanced_config() {
        let config = ImbalancedEnsembleConfig::builder()
            .n_estimators(5)
            .sampling_strategy(SamplingStrategy::SMOTE)
            .balanced_bootstrap(true)
            .build();

        assert_eq!(config.n_estimators, 5);
        assert_eq!(config.sampling_strategy, SamplingStrategy::SMOTE);
        assert!(config.balanced_bootstrap);
    }

    /// Class counting and inverse-frequency weighting on a skewed label set.
    #[test]
    fn test_class_distribution_analysis() {
        let config = ImbalancedEnsembleConfig::default();
        let mut classifier = ImbalancedEnsembleClassifier::new(config);

        // 4 samples of class 0, one each of classes 1 and 2.
        let y = vec![0, 0, 0, 0, 1, 2];
        classifier.analyze_class_distribution(&y).unwrap();

        assert_eq!(classifier.class_distributions.as_ref().unwrap()[&0], 4);
        assert_eq!(classifier.class_distributions.as_ref().unwrap()[&1], 1);
        assert_eq!(classifier.class_distributions.as_ref().unwrap()[&2], 1);

        // Minority classes must receive larger weights than the majority.
        assert!(
            classifier.class_weights.as_ref().unwrap()[&1]
                > classifier.class_weights.as_ref().unwrap()[&0]
        );
        assert!(
            classifier.class_weights.as_ref().unwrap()[&2]
                > classifier.class_weights.as_ref().unwrap()[&0]
        );
    }

    /// SMOTE should synthesize extra minority rows and keep shapes consistent.
    #[test]
    #[allow(non_snake_case)]
    fn test_smote_sampler() {
        let config = SMOTEConfig {
            k_neighbors: 1,
            sampling_strategy: 1.0,
            random_state: Some(42),
            selective: false,
            borderline_mode: BorderlineMode::Borderline1,
        };

        let mut sampler = SMOTESampler::new(config);

        // Two well-separated clusters: 4 majority rows, 2 minority rows.
        let X = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.1, 1.1, 1.2, 1.2, 1.3, 1.3, 5.0, 5.0, 5.1, 5.1],
        )
        .unwrap();
        let y = vec![0, 0, 0, 0, 1, 1];
        let (X_resampled, y_resampled) = sampler.fit_resample(&X, &y).unwrap();

        // Minority class grew; feature count unchanged; rows match labels.
        let class_1_count = y_resampled.iter().filter(|&&c| c == 1).count();
        assert!(class_1_count > 2);
        assert_eq!(X_resampled.shape()[1], X.shape()[1]);
        assert_eq!(X_resampled.shape()[0], y_resampled.len());
    }

    /// Construction stores the config and leaves the model untrained.
    #[test]
    fn test_imbalanced_ensemble_basic() {
        let config = ImbalancedEnsembleConfig::builder()
            .n_estimators(3)
            .sampling_strategy(SamplingStrategy::RandomOverSampling)
            .random_state(42)
            .build();

        let classifier = ImbalancedEnsembleClassifier::new(config);

        assert_eq!(classifier.config.n_estimators, 3);
        assert_eq!(
            classifier.config.sampling_strategy,
            SamplingStrategy::RandomOverSampling
        );
        assert!(classifier.base_classifiers.is_none());
    }

    /// Cost-matrix constructor wires up a boosting-style cost config.
    #[test]
    fn test_cost_sensitive_ensemble() {
        let cost_matrix = Array2::from_shape_vec((2, 2), vec![1.0, 3.0, 1.0, 1.0]).unwrap();
        let classifier = ImbalancedEnsembleClassifier::cost_sensitive(cost_matrix);

        assert!(classifier.config.cost_sensitive_config.is_some());
        let cost_config = classifier.config.cost_sensitive_config.as_ref().unwrap();
        assert_eq!(cost_config.cost_matrix.shape(), &[2, 2]);
        assert!(cost_config.class_balanced_weights);
        assert_eq!(
            cost_config.algorithm,
            CostSensitiveAlgorithm::CostSensitiveBoosting
        );
    }

    /// Custom per-class weights are stored verbatim.
    #[test]
    fn test_cost_sensitive_weights() {
        let mut class_weights = HashMap::new();
        class_weights.insert(0, 1.0);
        class_weights.insert(1, 3.0);
        let classifier =
            ImbalancedEnsembleClassifier::cost_sensitive_weights(class_weights.clone());

        assert!(classifier.config.cost_sensitive_config.is_some());
        let cost_config = classifier.config.cost_sensitive_config.as_ref().unwrap();
        assert_eq!(cost_config.class_weights, Some(class_weights));
        assert!(!cost_config.class_balanced_weights);
    }

    /// SMOTE ensemble preset uses SMOTE and balanced bootstraps.
    #[test]
    fn test_smote_ensemble_creation() {
        let classifier = ImbalancedEnsembleClassifier::smote_ensemble(3);

        assert_eq!(classifier.config.sampling_strategy, SamplingStrategy::SMOTE);
        assert!(classifier.config.smote_config.is_some());
        let smote_config = classifier.config.smote_config.as_ref().unwrap();
        assert_eq!(smote_config.k_neighbors, 3);
        assert!(classifier.config.balanced_bootstrap);
    }
}
1362}