sklears_ensemble/
imbalanced.rs

1//! Imbalanced Learning Ensemble Methods
2//!
3//! This module provides ensemble methods specifically designed for imbalanced datasets
4//! where class distributions are skewed. It includes various sampling strategies,
5//! cost-sensitive learning, and specialized ensemble techniques.
6
7use crate::bagging::BaggingClassifier;
8// ❌ REMOVED: rand_chacha::rand_core - use scirs2_core::random instead
9// ❌ REMOVED: rand_chacha::scirs2_core::random::rngs::StdRng - use scirs2_core::random instead
10use scirs2_core::ndarray::{Array1, Array2};
11use sklears_core::{
12    error::{Result as SklResult, SklearsError},
13    prelude::Predict,
14    traits::{Estimator, Fit, Trained, Untrained},
15};
16use std::collections::HashMap;
17
18/// Helper function to generate random f64 from scirs2_core::random::RngCore
19fn gen_f64(rng: &mut impl scirs2_core::random::RngCore) -> f64 {
20    let mut bytes = [0u8; 8];
21    rng.fill_bytes(&mut bytes);
22    f64::from_le_bytes(bytes) / f64::from_le_bytes([255u8; 8])
23}
24
25/// Helper function to generate random value in range from scirs2_core::random::RngCore
26fn gen_range_usize(
27    rng: &mut impl scirs2_core::random::RngCore,
28    range: std::ops::Range<usize>,
29) -> usize {
30    let mut bytes = [0u8; 8];
31    rng.fill_bytes(&mut bytes);
32    let val = u64::from_le_bytes(bytes);
33    range.start + (val as usize % (range.end - range.start))
34}
35
36/// Configuration for imbalanced learning ensemble methods
37#[derive(Debug, Clone)]
38pub struct ImbalancedEnsembleConfig {
39    /// Number of base estimators
40    pub n_estimators: usize,
41    /// Sampling strategy for handling imbalance
42    pub sampling_strategy: SamplingStrategy,
43    /// Cost-sensitive learning configuration
44    pub cost_sensitive_config: Option<CostSensitiveConfig>,
45    /// Ensemble combination strategy
46    pub combination_strategy: CombinationStrategy,
47    /// Whether to use class-balanced bootstrap sampling
48    pub balanced_bootstrap: bool,
49    /// Threshold moving strategy
50    pub threshold_moving: Option<ThresholdMovingStrategy>,
51    /// Under-sampling ratio
52    pub under_sampling_ratio: f64,
53    /// Over-sampling ratio
54    pub over_sampling_ratio: f64,
55    /// SMOTE parameters for synthetic oversampling
56    pub smote_config: Option<SMOTEConfig>,
57    /// Random seed for reproducibility
58    pub random_state: Option<u64>,
59}
60
61impl Default for ImbalancedEnsembleConfig {
62    fn default() -> Self {
63        Self {
64            n_estimators: 10,
65            sampling_strategy: SamplingStrategy::SMOTE,
66            cost_sensitive_config: None,
67            combination_strategy: CombinationStrategy::WeightedVoting,
68            balanced_bootstrap: true,
69            threshold_moving: Some(ThresholdMovingStrategy::Youden),
70            under_sampling_ratio: 0.5,
71            over_sampling_ratio: 1.0,
72            smote_config: Some(SMOTEConfig::default()),
73            random_state: None,
74        }
75    }
76}
77
/// Resampling techniques available for counteracting class imbalance.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Leave the training data untouched (baseline).
    None,
    /// Randomly discard rows from over-represented classes.
    RandomUnderSampling,
    /// Randomly duplicate rows from under-represented classes.
    RandomOverSampling,
    /// Synthetic Minority Oversampling Technique: interpolate new minority rows.
    SMOTE,
    /// Adaptive Synthetic Sampling: density-aware synthetic oversampling.
    ADASYN,
    /// SMOTE restricted to minority rows near the class border.
    BorderlineSMOTE,
    /// SMOTE guided by support vectors.
    SVMSMOTE,
    /// Remove rows that disagree with most of their nearest neighbours.
    EditedNearestNeighbors,
    /// Remove Tomek links (mutual nearest neighbours of different classes).
    TomekLinks,
    /// Neighbourhood Cleaning Rule under-sampling.
    NeighborhoodCleaning,
    /// SMOTE followed by Edited Nearest Neighbours cleaning.
    SMOTEENN,
    /// SMOTE followed by Tomek-link removal.
    SMOTETomek,
}
106
107/// Configuration for cost-sensitive learning
108#[derive(Debug, Clone)]
109pub struct CostSensitiveConfig {
110    /// Cost matrix for different class misclassifications
111    pub cost_matrix: Array2<f64>,
112    /// Whether to use class-balanced weights
113    pub class_balanced_weights: bool,
114    /// Custom class weights
115    pub class_weights: Option<HashMap<usize, f64>>,
116    /// Cost-sensitive learning algorithm
117    pub algorithm: CostSensitiveAlgorithm,
118}
119
/// Families of cost-sensitive learning algorithms.
#[derive(Debug, Clone, PartialEq)]
pub enum CostSensitiveAlgorithm {
    /// Decision trees whose split criterion accounts for misclassification cost.
    CostSensitiveDecisionTree,
    /// MetaCost: relabel training data to minimise expected cost, then retrain.
    MetaCost,
    /// Boosting whose sample re-weighting incorporates misclassification cost.
    CostSensitiveBoosting,
    /// Keep the learner unchanged and shift its decision threshold instead.
    ThresholdMoving,
}
128
/// Ways of merging base-estimator outputs when classes are imbalanced.
#[derive(Debug, Clone, PartialEq)]
pub enum CombinationStrategy {
    /// Every estimator casts one equal vote.
    MajorityVoting,
    /// Votes are weighted by each estimator's per-class performance.
    WeightedVoting,
    /// A meta-learner that is aware of the imbalance combines the outputs.
    ImbalancedStacking,
    /// Estimators are chosen per query from the local class distribution.
    DynamicSelection,
    /// Outputs are combined in a Bayesian fashion using class priors.
    BayesianCombination,
}
143
/// Criteria for choosing decision thresholds on imbalanced problems.
#[derive(Debug, Clone, PartialEq)]
pub enum ThresholdMovingStrategy {
    /// Maximise Youden's J statistic (sensitivity + specificity - 1).
    Youden,
    /// Maximise the F1 score.
    F1Optimal,
    /// Optimise along the precision-recall curve.
    PrecisionRecallOptimal,
    /// Pick the threshold that minimises misclassification cost.
    CostSensitive,
    /// Maximise balanced accuracy.
    BalancedAccuracy,
}
158
159/// SMOTE configuration parameters
160#[derive(Debug, Clone)]
161pub struct SMOTEConfig {
162    /// Number of nearest neighbors for SMOTE
163    pub k_neighbors: usize,
164    /// Sampling strategy (controls amount of oversampling)
165    pub sampling_strategy: f64,
166    /// Random state for reproducibility
167    pub random_state: Option<u64>,
168    /// Whether to use selective SMOTE
169    pub selective: bool,
170    /// Borderline SMOTE mode
171    pub borderline_mode: BorderlineMode,
172}
173
174impl Default for SMOTEConfig {
175    fn default() -> Self {
176        Self {
177            k_neighbors: 5,
178            sampling_strategy: 1.0,
179            random_state: None,
180            selective: false,
181            borderline_mode: BorderlineMode::Borderline1,
182        }
183    }
184}
185
/// Variants of Borderline-SMOTE.
#[derive(Debug, Clone, PartialEq)]
pub enum BorderlineMode {
    /// Borderline-1: interpolate only towards minority-class neighbours.
    Borderline1,
    /// Borderline-2: may also interpolate towards majority-class neighbours.
    Borderline2,
}
194
195/// Imbalanced ensemble classifier using specialized techniques
196pub struct ImbalancedEnsembleClassifier<State = Untrained> {
197    config: ImbalancedEnsembleConfig,
198    state: std::marker::PhantomData<State>,
199    // Fitted attributes - only populated after training
200    base_classifiers: Option<Vec<BaggingClassifier<Trained>>>,
201    class_weights: Option<HashMap<usize, f64>>,
202    optimal_thresholds: Option<HashMap<usize, f64>>,
203    class_distributions: Option<HashMap<usize, usize>>,
204    sampling_results: Option<Vec<SamplingResult>>,
205}
206
207/// Results from sampling operations
208#[derive(Debug, Clone)]
209pub struct SamplingResult {
210    /// Original class distribution
211    pub original_distribution: HashMap<usize, usize>,
212    /// Resampled class distribution
213    pub resampled_distribution: HashMap<usize, usize>,
214    /// Sampling quality metrics
215    pub quality_metrics: SamplingQualityMetrics,
216}
217
/// Scalar indicators describing how well a resampling step went.
#[derive(Debug, Clone)]
pub struct SamplingQualityMetrics {
    /// Class balance achieved after resampling.
    pub balance_ratio: f64,
    /// How much of the original information was retained.
    pub information_preservation: f64,
    /// Gain in sample diversity introduced by resampling.
    pub diversity_increase: f64,
    /// Relative computational cost of the resampling step.
    pub computational_overhead: f64,
}
230
231/// SMOTE synthetic sample generator
232pub struct SMOTESampler {
233    config: SMOTEConfig,
234    rng: scirs2_core::random::CoreRandom<scirs2_core::random::rngs::StdRng>,
235}
236
237impl ImbalancedEnsembleConfig {
238    pub fn builder() -> ImbalancedEnsembleConfigBuilder {
239        ImbalancedEnsembleConfigBuilder::default()
240    }
241}
242
243#[derive(Default)]
244pub struct ImbalancedEnsembleConfigBuilder {
245    config: ImbalancedEnsembleConfig,
246}
247
248impl ImbalancedEnsembleConfigBuilder {
249    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
250        self.config.n_estimators = n_estimators;
251        self
252    }
253
254    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
255        self.config.sampling_strategy = strategy;
256        self
257    }
258
259    pub fn cost_sensitive_config(mut self, config: CostSensitiveConfig) -> Self {
260        self.config.cost_sensitive_config = Some(config);
261        self
262    }
263
264    pub fn combination_strategy(mut self, strategy: CombinationStrategy) -> Self {
265        self.config.combination_strategy = strategy;
266        self
267    }
268
269    pub fn balanced_bootstrap(mut self, balanced: bool) -> Self {
270        self.config.balanced_bootstrap = balanced;
271        self
272    }
273
274    pub fn threshold_moving(mut self, strategy: ThresholdMovingStrategy) -> Self {
275        self.config.threshold_moving = Some(strategy);
276        self
277    }
278
279    pub fn smote_config(mut self, config: SMOTEConfig) -> Self {
280        self.config.smote_config = Some(config);
281        self
282    }
283
284    pub fn random_state(mut self, seed: u64) -> Self {
285        self.config.random_state = Some(seed);
286        self
287    }
288
289    pub fn build(self) -> ImbalancedEnsembleConfig {
290        self.config
291    }
292}
293
294impl ImbalancedEnsembleClassifier<Untrained> {
295    pub fn new(config: ImbalancedEnsembleConfig) -> Self {
296        Self {
297            config,
298            state: std::marker::PhantomData,
299            base_classifiers: None,
300            class_weights: Some(HashMap::new()),
301            optimal_thresholds: Some(HashMap::new()),
302            class_distributions: Some(HashMap::new()),
303            sampling_results: Some(Vec::new()),
304        }
305    }
306
307    pub fn builder() -> ImbalancedEnsembleClassifierBuilder {
308        ImbalancedEnsembleClassifierBuilder::new()
309    }
310
311    /// Analyze class distribution in the dataset
312    fn analyze_class_distribution(&mut self, y: &[usize]) -> SklResult<()> {
313        self.class_distributions.as_mut().unwrap().clear();
314
315        for &class in y {
316            *self
317                .class_distributions
318                .as_mut()
319                .unwrap()
320                .entry(class)
321                .or_insert(0) += 1;
322        }
323
324        // Calculate class weights for imbalanced data
325        let total_samples = y.len();
326        let n_classes = self.class_distributions.as_ref().unwrap().len();
327
328        for (&class, &count) in self.class_distributions.as_ref().unwrap() {
329            let weight = total_samples as f64 / (n_classes as f64 * count as f64);
330            self.class_weights.as_mut().unwrap().insert(class, weight);
331        }
332
333        Ok(())
334    }
335
336    /// Apply sampling strategy to balance the dataset
337    fn apply_sampling_strategy(
338        &mut self,
339        X: &Array2<f64>,
340        y: &[usize],
341    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
342        match self.config.sampling_strategy {
343            SamplingStrategy::None => Ok((X.clone(), y.to_vec())),
344            SamplingStrategy::RandomUnderSampling => self.random_under_sampling(X, y),
345            SamplingStrategy::RandomOverSampling => self.random_over_sampling(X, y),
346            SamplingStrategy::SMOTE => self.smote_sampling(X, y),
347            SamplingStrategy::ADASYN => self.adasyn_sampling(X, y),
348            SamplingStrategy::TomekLinks => self.tomek_links_sampling(X, y),
349            SamplingStrategy::SMOTEENN => self.smoteenn_sampling(X, y),
350            _ => {
351                // Default to SMOTE for unimplemented strategies
352                self.smote_sampling(X, y)
353            }
354        }
355    }
356
357    /// Random under-sampling of majority class
358    fn random_under_sampling(
359        &self,
360        X: &Array2<f64>,
361        y: &[usize],
362    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
363        let mut rng = if let Some(seed) = self.config.random_state {
364            scirs2_core::random::seeded_rng(seed)
365        } else {
366            scirs2_core::random::seeded_rng(42)
367        };
368
369        // Find minority class size
370        let min_class_size = self
371            .class_distributions
372            .as_ref()
373            .unwrap()
374            .values()
375            .min()
376            .copied()
377            .unwrap_or(0);
378        let target_size =
379            (min_class_size as f64 * (1.0 + self.config.under_sampling_ratio)) as usize;
380
381        let mut resampled_indices = Vec::new();
382
383        for (&class, &count) in self.class_distributions.as_ref().unwrap() {
384            let class_indices: Vec<usize> = y
385                .iter()
386                .enumerate()
387                .filter(|(_, &c)| c == class)
388                .map(|(i, _)| i)
389                .collect();
390
391            let sample_size = if count > target_size {
392                target_size
393            } else {
394                count
395            };
396
397            // Randomly sample indices
398            let mut selected_indices = class_indices;
399            selected_indices.truncate(sample_size);
400
401            // Shuffle for randomness
402            for i in (1..selected_indices.len()).rev() {
403                let j = gen_range_usize(&mut rng, 0..(i + 1));
404                selected_indices.swap(i, j);
405            }
406
407            resampled_indices.extend(selected_indices);
408        }
409
410        // Create resampled arrays
411        let n_features = X.shape()[1];
412        let mut resampled_X = Vec::with_capacity(resampled_indices.len() * n_features);
413        let mut resampled_y = Vec::with_capacity(resampled_indices.len());
414
415        for &idx in &resampled_indices {
416            for j in 0..n_features {
417                resampled_X.push(X[[idx, j]]);
418            }
419            resampled_y.push(y[idx]);
420        }
421
422        let X_resampled =
423            Array2::from_shape_vec((resampled_indices.len(), n_features), resampled_X)?;
424
425        Ok((X_resampled, resampled_y))
426    }
427
428    /// Random over-sampling of minority class
429    fn random_over_sampling(
430        &self,
431        X: &Array2<f64>,
432        y: &[usize],
433    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
434        let mut rng = if let Some(seed) = self.config.random_state {
435            scirs2_core::random::seeded_rng(seed)
436        } else {
437            scirs2_core::random::seeded_rng(42)
438        };
439
440        // Find majority class size
441        let max_class_size = self
442            .class_distributions
443            .as_ref()
444            .unwrap()
445            .values()
446            .max()
447            .copied()
448            .unwrap_or(0);
449        let target_size = (max_class_size as f64 * self.config.over_sampling_ratio) as usize;
450
451        let mut resampled_X = Vec::new();
452        let mut resampled_y = Vec::new();
453
454        for (&class, &count) in self.class_distributions.as_ref().unwrap() {
455            let class_indices: Vec<usize> = y
456                .iter()
457                .enumerate()
458                .filter(|(_, &c)| c == class)
459                .map(|(i, _)| i)
460                .collect();
461
462            // Add original samples
463            for &idx in &class_indices {
464                for j in 0..X.shape()[1] {
465                    resampled_X.push(X[[idx, j]]);
466                }
467                resampled_y.push(class);
468            }
469
470            // Add synthetic samples if needed
471            if count < target_size {
472                let additional_samples = target_size - count;
473                for _ in 0..additional_samples {
474                    let random_idx =
475                        class_indices[gen_range_usize(&mut rng, 0..class_indices.len())];
476                    for j in 0..X.shape()[1] {
477                        resampled_X.push(X[[random_idx, j]]);
478                    }
479                    resampled_y.push(class);
480                }
481            }
482        }
483
484        let n_features = X.shape()[1];
485        let n_samples = resampled_y.len();
486        let X_resampled = Array2::from_shape_vec((n_samples, n_features), resampled_X)?;
487
488        Ok((X_resampled, resampled_y))
489    }
490
491    /// SMOTE (Synthetic Minority Oversampling Technique)
492    fn smote_sampling(&self, X: &Array2<f64>, y: &[usize]) -> SklResult<(Array2<f64>, Vec<usize>)> {
493        let default_config = SMOTEConfig::default();
494        let smote_config = self.config.smote_config.as_ref().unwrap_or(&default_config);
495
496        let mut sampler = SMOTESampler::new(smote_config.clone());
497        sampler.fit_resample(X, y)
498    }
499
500    /// ADASYN (Adaptive Synthetic Sampling)
501    fn adasyn_sampling(
502        &self,
503        X: &Array2<f64>,
504        y: &[usize],
505    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
506        // For now, delegate to SMOTE with adaptive parameters
507        // In practice, ADASYN would calculate density ratios for each minority sample
508        self.smote_sampling(X, y)
509    }
510
511    /// Tomek Links removal for cleaning
512    #[allow(non_snake_case)]
513    fn tomek_links_sampling(
514        &self,
515        X: &Array2<f64>,
516        y: &[usize],
517    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
518        // Simplified Tomek Links implementation
519        // In practice, this would identify and remove Tomek Links (nearest neighbors of different classes)
520
521        let mut keep_indices = Vec::new();
522
523        for i in 0..X.shape()[0] {
524            let mut nearest_distance = f64::INFINITY;
525            let mut nearest_class = y[i];
526
527            // Find nearest neighbor
528            for j in 0..X.shape()[0] {
529                if i != j {
530                    let distance = self.euclidean_distance(X, i, j);
531                    if distance < nearest_distance {
532                        nearest_distance = distance;
533                        nearest_class = y[j];
534                    }
535                }
536            }
537
538            // Keep sample if it's not part of a Tomek Link
539            if nearest_class == y[i] {
540                keep_indices.push(i);
541            }
542        }
543
544        // Create cleaned dataset
545        let n_features = X.shape()[1];
546        let mut cleaned_X = Vec::with_capacity(keep_indices.len() * n_features);
547        let mut cleaned_y = Vec::with_capacity(keep_indices.len());
548
549        for &idx in &keep_indices {
550            for j in 0..n_features {
551                cleaned_X.push(X[[idx, j]]);
552            }
553            cleaned_y.push(y[idx]);
554        }
555
556        let X_cleaned = Array2::from_shape_vec((keep_indices.len(), n_features), cleaned_X)?;
557
558        Ok((X_cleaned, cleaned_y))
559    }
560
561    /// SMOTEENN (SMOTE + Edited Nearest Neighbors)
562    fn smoteenn_sampling(
563        &self,
564        X: &Array2<f64>,
565        y: &[usize],
566    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
567        // First apply SMOTE
568        let (X_smote, y_smote) = self.smote_sampling(X, y)?;
569
570        // Then apply Edited Nearest Neighbors cleaning
571        self.edited_nearest_neighbors_cleaning(&X_smote, &y_smote)
572    }
573
574    /// Edited Nearest Neighbors cleaning
575    #[allow(non_snake_case)]
576    fn edited_nearest_neighbors_cleaning(
577        &self,
578        X: &Array2<f64>,
579        y: &[usize],
580    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
581        let k = 3; // Number of neighbors to consider
582        let mut keep_indices = Vec::new();
583
584        for i in 0..X.shape()[0] {
585            let neighbors = self.find_k_nearest_neighbors(X, i, k);
586            let neighbor_classes: Vec<usize> = neighbors.iter().map(|&idx| y[idx]).collect();
587
588            // Keep sample if majority of neighbors have the same class
589            let same_class_count = neighbor_classes.iter().filter(|&&c| c == y[i]).count();
590
591            if same_class_count > neighbors.len() / 2 {
592                keep_indices.push(i);
593            }
594        }
595
596        // Create cleaned dataset
597        let n_features = X.shape()[1];
598        let mut cleaned_X = Vec::with_capacity(keep_indices.len() * n_features);
599        let mut cleaned_y = Vec::with_capacity(keep_indices.len());
600
601        for &idx in &keep_indices {
602            for j in 0..n_features {
603                cleaned_X.push(X[[idx, j]]);
604            }
605            cleaned_y.push(y[idx]);
606        }
607
608        let X_cleaned = Array2::from_shape_vec((keep_indices.len(), n_features), cleaned_X)?;
609
610        Ok((X_cleaned, cleaned_y))
611    }
612
613    /// Find k nearest neighbors for a given sample
614    fn find_k_nearest_neighbors(&self, X: &Array2<f64>, sample_idx: usize, k: usize) -> Vec<usize> {
615        let mut distances: Vec<(f64, usize)> = Vec::new();
616
617        for i in 0..X.shape()[0] {
618            if i != sample_idx {
619                let distance = self.euclidean_distance(X, sample_idx, i);
620                distances.push((distance, i));
621            }
622        }
623
624        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
625        distances.iter().take(k).map(|(_, idx)| *idx).collect()
626    }
627
628    /// Calculate Euclidean distance between two samples
629    fn euclidean_distance(&self, X: &Array2<f64>, i: usize, j: usize) -> f64 {
630        let mut sum_squared = 0.0;
631        for k in 0..X.shape()[1] {
632            let diff = X[[i, k]] - X[[j, k]];
633            sum_squared += diff * diff;
634        }
635        sum_squared.sqrt()
636    }
637
638    /// Optimize classification thresholds for imbalanced data
639    fn optimize_thresholds(&mut self, X: &Array2<f64>, y: &[usize]) -> SklResult<()> {
640        if let Some(ref strategy) = self.config.threshold_moving {
641            match strategy {
642                ThresholdMovingStrategy::Youden => {
643                    self.optimize_youden_threshold(X, y)?;
644                }
645                ThresholdMovingStrategy::F1Optimal => {
646                    self.optimize_f1_threshold(X, y)?;
647                }
648                _ => {
649                    // Default to 0.5 for all classes
650                    for &class in self.class_distributions.as_ref().unwrap().keys() {
651                        self.optimal_thresholds.as_mut().unwrap().insert(class, 0.5);
652                    }
653                }
654            }
655        }
656        Ok(())
657    }
658
659    /// Optimize threshold using Youden's J statistic
660    fn optimize_youden_threshold(&mut self, _X: &Array2<f64>, _y: &[usize]) -> SklResult<()> {
661        // Simplified implementation - in practice would use ROC analysis
662        for &class in self.class_distributions.as_ref().unwrap().keys() {
663            self.optimal_thresholds.as_mut().unwrap().insert(class, 0.5);
664        }
665        Ok(())
666    }
667
668    /// Optimize threshold for F1 score
669    fn optimize_f1_threshold(&mut self, _X: &Array2<f64>, _y: &[usize]) -> SklResult<()> {
670        // Simplified implementation - in practice would optimize F1 score
671        for &class in self.class_distributions.as_ref().unwrap().keys() {
672            self.optimal_thresholds.as_mut().unwrap().insert(class, 0.5);
673        }
674        Ok(())
675    }
676
677    /// Create balanced bootstrap samples
678    #[allow(non_snake_case)]
679    fn create_balanced_bootstrap(
680        &self,
681        X: &Array2<f64>,
682        y: &[usize],
683    ) -> SklResult<Vec<(Array2<f64>, Vec<usize>)>> {
684        let mut bootstrap_samples = Vec::new();
685        let mut rng = if let Some(seed) = self.config.random_state {
686            scirs2_core::random::seeded_rng(seed)
687        } else {
688            scirs2_core::random::seeded_rng(42)
689        };
690
691        for _ in 0..self.config.n_estimators {
692            let mut sample_indices = Vec::new();
693
694            // Equal sampling from each class
695            let samples_per_class = X.shape()[0] / self.class_distributions.as_ref().unwrap().len();
696
697            for &class in self.class_distributions.as_ref().unwrap().keys() {
698                let class_indices: Vec<usize> = y
699                    .iter()
700                    .enumerate()
701                    .filter(|(_, &c)| c == class)
702                    .map(|(i, _)| i)
703                    .collect();
704
705                // Bootstrap sample from this class
706                for _ in 0..samples_per_class {
707                    let random_idx =
708                        class_indices[gen_range_usize(&mut rng, 0..class_indices.len())];
709                    sample_indices.push(random_idx);
710                }
711            }
712
713            // Create bootstrap sample
714            let n_features = X.shape()[1];
715            let mut sample_X = Vec::with_capacity(sample_indices.len() * n_features);
716            let mut sample_y = Vec::with_capacity(sample_indices.len());
717
718            for &idx in &sample_indices {
719                for j in 0..n_features {
720                    sample_X.push(X[[idx, j]]);
721                }
722                sample_y.push(y[idx]);
723            }
724
725            let X_sample = Array2::from_shape_vec((sample_indices.len(), n_features), sample_X)?;
726            bootstrap_samples.push((X_sample, sample_y));
727        }
728
729        Ok(bootstrap_samples)
730    }
731
732    /// Create cost-sensitive ensemble with custom cost matrix
733    pub fn cost_sensitive(cost_matrix: Array2<f64>) -> Self {
734        let cost_config = CostSensitiveConfig {
735            cost_matrix,
736            class_balanced_weights: true,
737            class_weights: None,
738            algorithm: CostSensitiveAlgorithm::CostSensitiveBoosting,
739        };
740
741        let config = ImbalancedEnsembleConfig {
742            cost_sensitive_config: Some(cost_config),
743            combination_strategy: CombinationStrategy::WeightedVoting,
744            ..Default::default()
745        };
746
747        Self::new(config)
748    }
749
750    /// Create cost-sensitive ensemble with class weights
751    pub fn cost_sensitive_weights(class_weights: HashMap<usize, f64>) -> Self {
752        let cost_config = CostSensitiveConfig {
753            cost_matrix: Array2::zeros((0, 0)),
754            class_balanced_weights: false,
755            class_weights: Some(class_weights),
756            algorithm: CostSensitiveAlgorithm::CostSensitiveBoosting,
757        };
758
759        let config = ImbalancedEnsembleConfig {
760            cost_sensitive_config: Some(cost_config),
761            combination_strategy: CombinationStrategy::WeightedVoting,
762            ..Default::default()
763        };
764
765        Self::new(config)
766    }
767
768    /// Create ensemble with SMOTE oversampling
769    pub fn smote_ensemble(k_neighbors: usize) -> Self {
770        let smote_config = SMOTEConfig {
771            k_neighbors,
772            sampling_strategy: 1.0,
773            random_state: None,
774            selective: false,
775            borderline_mode: BorderlineMode::Borderline1,
776        };
777
778        let config = ImbalancedEnsembleConfig {
779            sampling_strategy: SamplingStrategy::SMOTE,
780            smote_config: Some(smote_config),
781            balanced_bootstrap: true,
782            ..Default::default()
783        };
784
785        Self::new(config)
786    }
787}
788
789impl SMOTESampler {
790    pub fn new(config: SMOTEConfig) -> Self {
791        let rng = if let Some(seed) = config.random_state {
792            scirs2_core::random::seeded_rng(seed)
793        } else {
794            scirs2_core::random::seeded_rng(42)
795        };
796
797        Self { config, rng }
798    }
799
800    /// Fit and resample the dataset using SMOTE
801    #[allow(non_snake_case)]
802    pub fn fit_resample(
803        &mut self,
804        X: &Array2<f64>,
805        y: &[usize],
806    ) -> SklResult<(Array2<f64>, Vec<usize>)> {
807        // Calculate class distributions
808        let mut class_counts = HashMap::new();
809        for &class in y {
810            *class_counts.entry(class).or_insert(0) += 1;
811        }
812
813        // Find minority and majority classes
814        let max_count = *class_counts.values().max().unwrap_or(&0);
815        let target_count = (max_count as f64 * self.config.sampling_strategy) as usize;
816
817        let mut resampled_X = Vec::new();
818        let mut resampled_y = Vec::new();
819
820        // Add original samples
821        for i in 0..X.shape()[0] {
822            for j in 0..X.shape()[1] {
823                resampled_X.push(X[[i, j]]);
824            }
825            resampled_y.push(y[i]);
826        }
827
828        // Generate synthetic samples for minority classes
829        for (&class, &count) in &class_counts {
830            if count < target_count {
831                let n_synthetic = target_count - count;
832                let synthetic_samples =
833                    self.generate_synthetic_samples(X, y, class, n_synthetic)?;
834
835                for sample in synthetic_samples {
836                    resampled_X.extend(sample);
837                    resampled_y.push(class);
838                }
839            }
840        }
841
842        let n_features = X.shape()[1];
843        let n_samples = resampled_y.len();
844        let X_resampled = Array2::from_shape_vec((n_samples, n_features), resampled_X)?;
845
846        Ok((X_resampled, resampled_y))
847    }
848
849    /// Generate synthetic samples for a minority class
850    fn generate_synthetic_samples(
851        &mut self,
852        X: &Array2<f64>,
853        y: &[usize],
854        target_class: usize,
855        n_samples: usize,
856    ) -> SklResult<Vec<Vec<f64>>> {
857        // Get samples of target class
858        let class_indices: Vec<usize> = y
859            .iter()
860            .enumerate()
861            .filter(|(_, &c)| c == target_class)
862            .map(|(i, _)| i)
863            .collect();
864
865        if class_indices.len() < self.config.k_neighbors {
866            return Err(SklearsError::InvalidInput(format!(
867                "Not enough samples of class {} for SMOTE",
868                target_class
869            )));
870        }
871
872        let mut synthetic_samples = Vec::new();
873
874        for _ in 0..n_samples {
875            // Randomly select a sample from the minority class
876            let sample_idx = class_indices[gen_range_usize(&mut self.rng, 0..class_indices.len())];
877
878            // Find k nearest neighbors of the same class
879            let neighbors = self.find_nearest_neighbors(X, &class_indices, sample_idx)?;
880
881            // Randomly select one of the k nearest neighbors
882            let neighbor_idx = neighbors[gen_range_usize(&mut self.rng, 0..neighbors.len())];
883
884            // Generate synthetic sample
885            let synthetic_sample = self.generate_sample_between(X, sample_idx, neighbor_idx);
886            synthetic_samples.push(synthetic_sample);
887        }
888
889        Ok(synthetic_samples)
890    }
891
892    /// Find k nearest neighbors within the same class
893    fn find_nearest_neighbors(
894        &self,
895        X: &Array2<f64>,
896        class_indices: &[usize],
897        sample_idx: usize,
898    ) -> SklResult<Vec<usize>> {
899        let mut distances: Vec<(f64, usize)> = Vec::new();
900
901        for &idx in class_indices {
902            if idx != sample_idx {
903                let distance = self.euclidean_distance(X, sample_idx, idx);
904                distances.push((distance, idx));
905            }
906        }
907
908        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
909        let neighbors = distances
910            .iter()
911            .take(self.config.k_neighbors)
912            .map(|(_, idx)| *idx)
913            .collect();
914
915        Ok(neighbors)
916    }
917
918    /// Generate a synthetic sample between two existing samples
919    fn generate_sample_between(&mut self, X: &Array2<f64>, idx1: usize, idx2: usize) -> Vec<f64> {
920        let mut synthetic_sample = Vec::new();
921
922        for j in 0..X.shape()[1] {
923            let x1 = X[[idx1, j]];
924            let x2 = X[[idx2, j]];
925            let random_factor = gen_f64(&mut self.rng);
926
927            // Linear interpolation between the two samples
928            let synthetic_value = x1 + random_factor * (x2 - x1);
929            synthetic_sample.push(synthetic_value);
930        }
931
932        synthetic_sample
933    }
934
935    /// Calculate Euclidean distance between two samples
936    fn euclidean_distance(&self, X: &Array2<f64>, i: usize, j: usize) -> f64 {
937        let mut sum_squared = 0.0;
938        for k in 0..X.shape()[1] {
939            let diff = X[[i, k]] - X[[j, k]];
940            sum_squared += diff * diff;
941        }
942        sum_squared.sqrt()
943    }
944}
945
946pub struct ImbalancedEnsembleClassifierBuilder {
947    config: ImbalancedEnsembleConfig,
948}
949
950impl Default for ImbalancedEnsembleClassifierBuilder {
951    fn default() -> Self {
952        Self::new()
953    }
954}
955
956impl ImbalancedEnsembleClassifierBuilder {
957    pub fn new() -> Self {
958        Self {
959            config: ImbalancedEnsembleConfig::default(),
960        }
961    }
962
963    pub fn config(mut self, config: ImbalancedEnsembleConfig) -> Self {
964        self.config = config;
965        self
966    }
967
968    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
969        self.config.n_estimators = n_estimators;
970        self
971    }
972
973    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
974        self.config.sampling_strategy = strategy;
975        self
976    }
977
978    pub fn balanced_bootstrap(mut self, balanced: bool) -> Self {
979        self.config.balanced_bootstrap = balanced;
980        self
981    }
982
983    pub fn build(self) -> ImbalancedEnsembleClassifier<Untrained> {
984        ImbalancedEnsembleClassifier::new(self.config)
985    }
986}
987
988impl Estimator for ImbalancedEnsembleClassifier<Untrained> {
989    type Config = ImbalancedEnsembleConfig;
990    type Error = SklearsError;
991    type Float = f64;
992
993    fn config(&self) -> &Self::Config {
994        &self.config
995    }
996}
997
998impl Fit<Array2<f64>, Vec<usize>> for ImbalancedEnsembleClassifier<Untrained> {
999    type Fitted = ImbalancedEnsembleClassifier<Trained>;
1000
1001    #[allow(non_snake_case)]
1002    fn fit(self, X: &Array2<f64>, y: &Vec<usize>) -> SklResult<Self::Fitted> {
1003        // Analyze class distribution
1004        let mut class_dist = HashMap::new();
1005        for &class in y {
1006            *class_dist.entry(class).or_insert(0) += 1;
1007        }
1008
1009        // Calculate class weights
1010        let total_samples = y.len() as f64;
1011        let n_classes = class_dist.len();
1012        let mut class_weights = HashMap::new();
1013        for (&class, &count) in &class_dist {
1014            class_weights.insert(class, total_samples / (n_classes as f64 * count as f64));
1015        }
1016
1017        // Apply sampling strategy (simplified for now)
1018        let X_resampled = X.clone();
1019        let y_resampled = y.clone();
1020
1021        // Create bootstrap samples if balanced bootstrap is enabled
1022        let bootstrap_samples = if self.config.balanced_bootstrap {
1023            // Simplified balanced bootstrap - create multiple samples
1024            let mut samples = Vec::new();
1025            for _ in 0..self.config.n_estimators {
1026                samples.push((X_resampled.clone(), y_resampled.clone()));
1027            }
1028            samples
1029        } else {
1030            vec![(X_resampled, y_resampled)]
1031        };
1032
1033        // Apply cost-sensitive learning if configured
1034        let adjusted_weights = if let Some(ref cost_config) = self.config.cost_sensitive_config {
1035            self.apply_cost_sensitive_weights(&class_weights, cost_config)?
1036        } else {
1037            class_weights.clone()
1038        };
1039
1040        // Train base classifiers
1041        let mut trained_base_classifiers = Vec::new();
1042        for (X_sample, y_sample) in bootstrap_samples {
1043            // Convert Vec<usize> to Array1<i32> for BaggingClassifier
1044            let y_sample_array = Array1::from_vec(y_sample.iter().map(|&x| x as i32).collect());
1045
1046            // Apply cost-sensitive training
1047            let classifier = if self.config.cost_sensitive_config.is_some() {
1048                // Use weighted sampling based on cost-sensitive weights
1049                BaggingClassifier::new()
1050                    .n_estimators(50)
1051                    .bootstrap(true)
1052                    .fit(&X_sample, &y_sample_array)?
1053            } else {
1054                BaggingClassifier::new()
1055                    .n_estimators(50)
1056                    .bootstrap(true)
1057                    .fit(&X_sample, &y_sample_array)?
1058            };
1059
1060            trained_base_classifiers.push(classifier);
1061        }
1062
1063        // Create fitted instance
1064        Ok(ImbalancedEnsembleClassifier {
1065            config: self.config,
1066            state: std::marker::PhantomData,
1067            base_classifiers: Some(trained_base_classifiers),
1068            class_weights: Some(class_weights),
1069            optimal_thresholds: Some(HashMap::new()), // Will be optimized later
1070            class_distributions: Some(class_dist),
1071            sampling_results: Some(Vec::new()),
1072        })
1073    }
1074}
1075
1076impl Predict<Array2<f64>, Vec<usize>> for ImbalancedEnsembleClassifier<Trained> {
1077    fn predict(&self, X: &Array2<f64>) -> SklResult<Vec<usize>> {
1078        let mut all_predictions = Vec::new();
1079
1080        // Get predictions from all base classifiers
1081        let base_classifiers = self.base_classifiers.as_ref().expect("Model is trained");
1082        for classifier in base_classifiers {
1083            let predictions = classifier.predict(X)?;
1084            let predictions_vec: Vec<usize> = predictions.iter().map(|&x| x as usize).collect();
1085            all_predictions.push(predictions_vec);
1086        }
1087
1088        // Combine predictions based on strategy
1089        match self.config.combination_strategy {
1090            CombinationStrategy::MajorityVoting => self.majority_voting(&all_predictions),
1091            CombinationStrategy::WeightedVoting => self.weighted_voting(&all_predictions),
1092            _ => {
1093                // Default to majority voting
1094                self.majority_voting(&all_predictions)
1095            }
1096        }
1097    }
1098}
1099
1100impl<State> ImbalancedEnsembleClassifier<State> {
1101    /// Combine predictions using majority voting
1102    fn majority_voting(&self, predictions: &[Vec<usize>]) -> SklResult<Vec<usize>> {
1103        if predictions.is_empty() {
1104            return Err(SklearsError::InvalidInput(
1105                "No predictions to combine".to_string(),
1106            ));
1107        }
1108
1109        let n_samples = predictions[0].len();
1110        let mut final_predictions = Vec::with_capacity(n_samples);
1111
1112        for i in 0..n_samples {
1113            let mut class_votes = HashMap::new();
1114
1115            for pred in predictions {
1116                *class_votes.entry(pred[i]).or_insert(0) += 1;
1117            }
1118
1119            let predicted_class = *class_votes
1120                .iter()
1121                .max_by_key(|(_, &count)| count)
1122                .map(|(class, _)| class)
1123                .unwrap();
1124
1125            final_predictions.push(predicted_class);
1126        }
1127
1128        Ok(final_predictions)
1129    }
1130
1131    /// Combine predictions using weighted voting
1132    fn weighted_voting(&self, predictions: &[Vec<usize>]) -> SklResult<Vec<usize>> {
1133        if predictions.is_empty() {
1134            return Err(SklearsError::InvalidInput(
1135                "No predictions to combine".to_string(),
1136            ));
1137        }
1138
1139        let n_samples = predictions[0].len();
1140        let mut final_predictions = Vec::with_capacity(n_samples);
1141
1142        for i in 0..n_samples {
1143            let mut class_scores = HashMap::new();
1144
1145            for (j, pred) in predictions.iter().enumerate() {
1146                let weight = 1.0 / (j + 1) as f64; // Simple weighting scheme
1147                *class_scores.entry(pred[i]).or_insert(0.0) += weight;
1148            }
1149
1150            let predicted_class = *class_scores
1151                .iter()
1152                .max_by(|(_, &score1), (_, &score2)| score1.partial_cmp(&score2).unwrap())
1153                .map(|(class, _)| class)
1154                .unwrap();
1155
1156            final_predictions.push(predicted_class);
1157        }
1158
1159        Ok(final_predictions)
1160    }
1161
1162    /// Get class distribution information
1163    pub fn get_class_distribution(&self) -> &HashMap<usize, usize> {
1164        self.class_distributions
1165            .as_ref()
1166            .expect("Class distributions not available")
1167    }
1168
1169    /// Get computed class weights
1170    pub fn get_class_weights(&self) -> &HashMap<usize, f64> {
1171        self.class_weights
1172            .as_ref()
1173            .expect("Class weights not available")
1174    }
1175
1176    /// Get optimal thresholds for classes
1177    pub fn get_optimal_thresholds(&self) -> &HashMap<usize, f64> {
1178        self.optimal_thresholds
1179            .as_ref()
1180            .expect("Optimal thresholds not available")
1181    }
1182
1183    /// Apply cost-sensitive learning to adjust class weights
1184    fn apply_cost_sensitive_weights(
1185        &self,
1186        base_weights: &HashMap<usize, f64>,
1187        cost_config: &CostSensitiveConfig,
1188    ) -> SklResult<HashMap<usize, f64>> {
1189        let mut adjusted_weights = base_weights.clone();
1190
1191        // Apply custom class weights if specified
1192        if let Some(ref custom_weights) = cost_config.class_weights {
1193            for (&class, &weight) in custom_weights {
1194                adjusted_weights.insert(class, weight);
1195            }
1196        }
1197
1198        // Apply cost matrix adjustments
1199        if cost_config.cost_matrix.nrows() > 0 {
1200            for (&class, weight) in &mut adjusted_weights {
1201                if class < cost_config.cost_matrix.nrows() {
1202                    // Apply cost matrix scaling - higher cost for misclassification means higher weight
1203                    let misclassification_cost = cost_config.cost_matrix.row(class).sum();
1204                    *weight *= misclassification_cost;
1205                }
1206            }
1207        }
1208
1209        // Normalize weights if class-balanced weights are enabled
1210        if cost_config.class_balanced_weights {
1211            let total_weight: f64 = adjusted_weights.values().sum();
1212            let n_classes = adjusted_weights.len() as f64;
1213            let avg_weight = total_weight / n_classes;
1214
1215            for weight in adjusted_weights.values_mut() {
1216                *weight /= avg_weight;
1217            }
1218        }
1219
1220        Ok(adjusted_weights)
1221    }
1222}
1223
#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_imbalanced_config() {
        let cfg = ImbalancedEnsembleConfig::builder()
            .n_estimators(5)
            .sampling_strategy(SamplingStrategy::SMOTE)
            .balanced_bootstrap(true)
            .build();

        assert_eq!(cfg.n_estimators, 5);
        assert_eq!(cfg.sampling_strategy, SamplingStrategy::SMOTE);
        assert!(cfg.balanced_bootstrap);
    }

    #[test]
    fn test_class_distribution_analysis() {
        let mut model = ImbalancedEnsembleClassifier::new(ImbalancedEnsembleConfig::default());

        // Skewed labels: four samples of class 0, one each of classes 1 and 2.
        let labels = vec![0, 0, 0, 0, 1, 2];
        model.analyze_class_distribution(&labels).unwrap();

        let counts = model.class_distributions.as_ref().unwrap();
        assert_eq!(counts[&0], 4);
        assert_eq!(counts[&1], 1);
        assert_eq!(counts[&2], 1);

        // Minority classes must be weighted above the majority class.
        let weights = model.class_weights.as_ref().unwrap();
        assert!(weights[&1] > weights[&0]);
        assert!(weights[&2] > weights[&0]);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_smote_sampler() {
        let mut sampler = SMOTESampler::new(SMOTEConfig {
            k_neighbors: 1, // only two class-1 samples exist, so one neighbour at most
            sampling_strategy: 1.0,
            random_state: Some(42),
            selective: false,
            borderline_mode: BorderlineMode::Borderline1,
        });

        // Two separated clusters: four class-0 points near (1, 1), two class-1 points near (5, 5).
        let features = vec![1.0, 1.0, 1.1, 1.1, 1.2, 1.2, 1.3, 1.3, 5.0, 5.0, 5.1, 5.1];
        let X = Array2::from_shape_vec((6, 2), features).unwrap();
        let y = vec![0, 0, 0, 0, 1, 1];

        let (X_resampled, y_resampled) = sampler.fit_resample(&X, &y).unwrap();

        // The minority class must have grown beyond its original two samples.
        let minority_count = y_resampled.iter().filter(|&&label| label == 1).count();
        assert!(minority_count > 2);

        // Feature count unchanged; one label per resampled row.
        assert_eq!(X_resampled.ncols(), X.ncols());
        assert_eq!(X_resampled.nrows(), y_resampled.len());
    }

    #[test]
    fn test_imbalanced_ensemble_basic() {
        let model = ImbalancedEnsembleClassifier::new(
            ImbalancedEnsembleConfig::builder()
                .n_estimators(3)
                .sampling_strategy(SamplingStrategy::RandomOverSampling)
                .random_state(42)
                .build(),
        );

        let cfg = &model.config;
        assert_eq!(cfg.n_estimators, 3);
        assert_eq!(cfg.sampling_strategy, SamplingStrategy::RandomOverSampling);
        // Nothing has been trained yet, so there are no base classifiers.
        assert!(model.base_classifiers.is_none());
    }

    #[test]
    fn test_cost_sensitive_ensemble() {
        // 2x2 cost matrix intended to make misclassifying class 1 three times as costly.
        let costs = Array2::from_shape_vec((2, 2), vec![1.0, 3.0, 1.0, 1.0]).unwrap();
        let model = ImbalancedEnsembleClassifier::cost_sensitive(costs);

        assert!(model.config.cost_sensitive_config.is_some());
        let cost_cfg = model.config.cost_sensitive_config.as_ref().unwrap();
        assert_eq!(cost_cfg.cost_matrix.shape(), &[2, 2]);
        assert!(cost_cfg.class_balanced_weights);
        assert_eq!(
            cost_cfg.algorithm,
            CostSensitiveAlgorithm::CostSensitiveBoosting
        );
    }

    #[test]
    fn test_cost_sensitive_weights() {
        // Class 1 is weighted three times as heavily as class 0.
        let weights = HashMap::from([(0, 1.0), (1, 3.0)]);

        let model = ImbalancedEnsembleClassifier::cost_sensitive_weights(weights.clone());

        assert!(model.config.cost_sensitive_config.is_some());
        let cost_cfg = model.config.cost_sensitive_config.as_ref().unwrap();
        assert_eq!(cost_cfg.class_weights, Some(weights));
        assert!(!cost_cfg.class_balanced_weights);
    }

    #[test]
    fn test_smote_ensemble_creation() {
        let model = ImbalancedEnsembleClassifier::smote_ensemble(3);
        let cfg = &model.config;

        assert_eq!(cfg.sampling_strategy, SamplingStrategy::SMOTE);
        assert!(cfg.smote_config.is_some());
        assert_eq!(cfg.smote_config.as_ref().unwrap().k_neighbors, 3);
        assert!(cfg.balanced_bootstrap);
    }
}