sklears_ensemble/
bagging.rs

1//! Bagging ensemble methods
2//!
3//! Bootstrap aggregating (bagging) trains multiple base estimators on random
4//! subsets of the training data and aggregates their predictions.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::prelude::*;
9use sklears_core::error::{Result, SklearsError};
10use sklears_core::prelude::Predict;
11use sklears_core::traits::{Estimator, Fit, Trained, Untrained};
12use sklears_core::types::{Float, Int};
13// use sklears_tree::{DecisionTreeClassifier, DecisionTreeRegressor, SplitCriterion}; // Temporarily disabled
14use crate::adaboost::{DecisionTreeClassifier, DecisionTreeRegressor, SplitCriterion};
15#[allow(unused_imports)]
16use std::collections::HashSet;
17use std::marker::PhantomData;
18
19#[cfg(feature = "parallel")]
20use rayon::prelude::*;
21
22/// Configuration for bagging ensemble
23#[derive(Debug, Clone)]
24pub struct BaggingConfig {
25    /// Number of base estimators in the ensemble
26    pub n_estimators: usize,
27    /// Number of samples to draw from X to train each base estimator
28    pub max_samples: Option<usize>,
29    /// Number of features to draw from X to train each base estimator
30    pub max_features: Option<usize>,
31    /// Whether to use replacement when sampling
32    pub bootstrap: bool,
33    /// Whether to use replacement when sampling features
34    pub bootstrap_features: bool,
35    /// Random state for reproducible results
36    pub random_state: Option<u64>,
37    /// Out-of-bag score calculation
38    pub oob_score: bool,
39    /// Number of jobs for parallel execution
40    pub n_jobs: Option<i32>,
41    /// Maximum depth for decision tree base estimators
42    pub max_depth: Option<usize>,
43    /// Minimum samples required to split an internal node
44    pub min_samples_split: usize,
45    /// Minimum samples required to be at a leaf node
46    pub min_samples_leaf: usize,
47    /// Bootstrap confidence level for intervals
48    pub confidence_level: Float,
49    /// Use extra randomization (Extremely Randomized Trees)
50    pub extra_randomized: bool,
51}
52
53impl Default for BaggingConfig {
54    fn default() -> Self {
55        Self {
56            n_estimators: 10,
57            max_samples: None,
58            max_features: None,
59            bootstrap: true,
60            bootstrap_features: false,
61            random_state: None,
62            oob_score: false,
63            n_jobs: None,
64            max_depth: None,
65            min_samples_split: 2,
66            min_samples_leaf: 1,
67            confidence_level: 0.95,
68            extra_randomized: false,
69        }
70    }
71}
72
73/// Enhanced Bagging classifier with OOB estimation and feature bagging
74pub struct BaggingClassifier<State = Untrained> {
75    config: BaggingConfig,
76    state: PhantomData<State>,
77    // Fitted parameters
78    estimators_: Option<Vec<DecisionTreeClassifier<Trained>>>,
79    estimators_features_: Option<Vec<Vec<usize>>>,
80    estimators_samples_: Option<Vec<Vec<usize>>>,
81    oob_score_: Option<Float>,
82    oob_prediction_: Option<Array1<Float>>,
83    classes_: Option<Array1<Int>>,
84    n_classes_: Option<usize>,
85    n_features_in_: Option<usize>,
86    feature_importances_: Option<Array1<Float>>,
87}
88
89impl BaggingClassifier<Untrained> {
90    /// Create a new bagging classifier
91    pub fn new() -> Self {
92        Self {
93            config: BaggingConfig::default(),
94            state: PhantomData,
95            estimators_: None,
96            estimators_features_: None,
97            estimators_samples_: None,
98            oob_score_: None,
99            oob_prediction_: None,
100            classes_: None,
101            n_classes_: None,
102            n_features_in_: None,
103            feature_importances_: None,
104        }
105    }
106
107    /// Set the number of estimators
108    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
109        self.config.n_estimators = n_estimators;
110        self
111    }
112
113    /// Set the maximum number of samples per estimator
114    pub fn max_samples(mut self, max_samples: Option<usize>) -> Self {
115        self.config.max_samples = max_samples;
116        self
117    }
118
119    /// Set the maximum number of features per estimator
120    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
121        self.config.max_features = max_features;
122        self
123    }
124
125    /// Set whether to use bootstrap sampling
126    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
127        self.config.bootstrap = bootstrap;
128        self
129    }
130
131    /// Set whether to use bootstrap feature sampling
132    pub fn bootstrap_features(mut self, bootstrap_features: bool) -> Self {
133        self.config.bootstrap_features = bootstrap_features;
134        self
135    }
136
137    /// Set the random state
138    pub fn random_state(mut self, random_state: u64) -> Self {
139        self.config.random_state = Some(random_state);
140        self
141    }
142
143    /// Set whether to calculate out-of-bag score
144    pub fn oob_score(mut self, oob_score: bool) -> Self {
145        self.config.oob_score = oob_score;
146        self
147    }
148
149    /// Set maximum depth for base estimators
150    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
151        self.config.max_depth = max_depth;
152        self
153    }
154
155    /// Set minimum samples to split
156    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
157        self.config.min_samples_split = min_samples_split;
158        self
159    }
160
161    /// Set minimum samples at leaf
162    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
163        self.config.min_samples_leaf = min_samples_leaf;
164        self
165    }
166
167    /// Set confidence level for bootstrap intervals
168    pub fn confidence_level(mut self, confidence_level: Float) -> Self {
169        self.config.confidence_level = confidence_level;
170        self
171    }
172
173    /// Set number of parallel jobs for training (None for automatic detection)
174    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
175        self.config.n_jobs = n_jobs;
176        self
177    }
178
179    /// Enable parallel training with automatic job detection
180    pub fn parallel(mut self) -> Self {
181        self.config.n_jobs = Some(-1); // -1 means use all available cores
182        self
183    }
184
185    /// Enable extra randomization (Extremely Randomized Trees)
186    pub fn extra_randomized(mut self, extra_randomized: bool) -> Self {
187        self.config.extra_randomized = extra_randomized;
188        self
189    }
190
191    /// Enable extra randomization (convenient shorthand)
192    pub fn extremely_randomized(mut self) -> Self {
193        self.config.extra_randomized = true;
194        self.config.bootstrap = false; // Extra trees typically don't use bootstrap
195        self
196    }
197
198    /// Generate bootstrap samples with optional out-of-bag tracking
199    /// For extra randomized trees, returns the full dataset instead of bootstrap samples
200    fn bootstrap_sample(
201        &self,
202        x: &Array2<Float>,
203        y: &Array1<Int>,
204        rng: &mut StdRng,
205    ) -> Result<(Array2<Float>, Array1<Int>, Vec<usize>)> {
206        let n_samples = x.nrows();
207
208        // For extra randomized trees, use the full dataset
209        if self.config.extra_randomized {
210            let sample_indices: Vec<usize> = (0..n_samples).collect();
211            return Ok((x.clone(), y.clone(), sample_indices));
212        }
213
214        let max_samples = self.config.max_samples.unwrap_or(n_samples);
215
216        // Find unique classes and their indices
217        let mut class_indices: std::collections::HashMap<Int, Vec<usize>> =
218            std::collections::HashMap::new();
219        for (idx, &class) in y.iter().enumerate() {
220            class_indices.entry(class).or_default().push(idx);
221        }
222
223        let mut sample_indices = Vec::new();
224
225        if self.config.bootstrap {
226            // First, do standard bootstrap sampling
227            sample_indices = (0..max_samples)
228                .map(|_| rng.gen_range(0..n_samples))
229                .collect();
230        } else {
231            // Random sampling without replacement
232            let mut indices: Vec<usize> = (0..n_samples).collect();
233            indices.shuffle(rng);
234            indices.truncate(max_samples);
235            sample_indices = indices;
236        }
237
238        // Verify we have at least 2 classes in the sample
239        let mut sampled_classes = std::collections::HashSet::new();
240        for &idx in &sample_indices {
241            sampled_classes.insert(y[idx]);
242        }
243
244        // If we still have only one class, force diversity by replacing some samples
245        if sampled_classes.len() < 2 && class_indices.len() >= 2 {
246            let mut other_classes: Vec<Int> = class_indices
247                .keys()
248                .filter(|&&c| !sampled_classes.contains(&c))
249                .cloned()
250                .collect();
251
252            // Sort for deterministic behavior
253            other_classes.sort();
254
255            if !other_classes.is_empty() {
256                // Replace the last sample with one from the first other class (deterministic)
257                let other_class = other_classes[0];
258                if let Some(other_indices) = class_indices.get(&other_class) {
259                    if !other_indices.is_empty() {
260                        // Use the first available index for deterministic behavior
261                        let replacement_idx = other_indices[0];
262                        if let Some(last) = sample_indices.last_mut() {
263                            *last = replacement_idx;
264                        }
265                    }
266                }
267            }
268        }
269
270        let mut x_bootstrap = Array2::zeros((max_samples, x.ncols()));
271        let mut y_bootstrap = Array1::zeros(max_samples);
272
273        for (i, &idx) in sample_indices.iter().enumerate() {
274            x_bootstrap.row_mut(i).assign(&x.row(idx));
275            y_bootstrap[i] = y[idx];
276        }
277
278        Ok((x_bootstrap, y_bootstrap, sample_indices))
279    }
280
281    /// Train ensemble using parallel processing with work-stealing when available
282    fn train_ensemble_parallel(
283        &self,
284        x: &Array2<Float>,
285        y: &Array1<Int>,
286        rng: &mut StdRng,
287        n_features: usize,
288    ) -> Result<(
289        Vec<DecisionTreeClassifier<Trained>>,
290        Vec<Vec<usize>>,
291        Vec<Vec<usize>>,
292    )> {
293        // Pre-generate all bootstrap samples and feature indices to maintain determinism
294        let mut bootstrap_data = Vec::new();
295        for i in 0..self.config.n_estimators {
296            let mut local_rng =
297                StdRng::seed_from_u64(self.config.random_state.unwrap_or(42) + i as u64);
298
299            let (x_bootstrap, y_bootstrap, sample_indices) =
300                self.bootstrap_sample(x, y, &mut local_rng)?;
301            let feature_indices = self.get_feature_indices(n_features, &mut local_rng);
302
303            bootstrap_data.push((x_bootstrap, y_bootstrap, sample_indices, feature_indices));
304        }
305
306        // Determine whether to use parallel processing
307        let use_parallel = self.should_use_parallel();
308
309        if use_parallel {
310            #[cfg(feature = "parallel")]
311            {
312                // Parallel training with work-stealing using rayon
313                let results: Result<Vec<_>> = bootstrap_data
314                    .into_par_iter()
315                    .enumerate()
316                    .map(
317                        |(i, (x_bootstrap, y_bootstrap, sample_indices, feature_indices))| {
318                            self.fit_single_estimator(
319                                &x_bootstrap,
320                                &y_bootstrap,
321                                &feature_indices,
322                                i,
323                            )
324                            .map(|estimator| (estimator, feature_indices, sample_indices))
325                        },
326                    )
327                    .collect();
328
329                let fitted_data = results?;
330                let (estimators, estimators_features, estimators_samples) =
331                    fitted_data.into_iter().fold(
332                        (Vec::new(), Vec::new(), Vec::new()),
333                        |(mut e, mut ef, mut es), (estimator, features, samples)| {
334                            e.push(estimator);
335                            ef.push(features);
336                            es.push(samples);
337                            (e, ef, es)
338                        },
339                    );
340
341                Ok((estimators, estimators_features, estimators_samples))
342            }
343            #[cfg(not(feature = "parallel"))]
344            {
345                // Fall back to sequential if parallel feature is not enabled
346                self.train_ensemble_sequential(bootstrap_data)
347            }
348        } else {
349            // Sequential training
350            self.train_ensemble_sequential(bootstrap_data)
351        }
352    }
353
354    /// Train ensemble sequentially
355    fn train_ensemble_sequential(
356        &self,
357        bootstrap_data: Vec<(Array2<Float>, Array1<Int>, Vec<usize>, Vec<usize>)>,
358    ) -> Result<(
359        Vec<DecisionTreeClassifier<Trained>>,
360        Vec<Vec<usize>>,
361        Vec<Vec<usize>>,
362    )> {
363        let mut estimators = Vec::new();
364        let mut estimators_features = Vec::new();
365        let mut estimators_samples = Vec::new();
366
367        for (i, (x_bootstrap, y_bootstrap, sample_indices, feature_indices)) in
368            bootstrap_data.into_iter().enumerate()
369        {
370            let fitted_tree =
371                self.fit_single_estimator(&x_bootstrap, &y_bootstrap, &feature_indices, i)?;
372
373            estimators.push(fitted_tree);
374            estimators_features.push(feature_indices);
375            estimators_samples.push(sample_indices);
376        }
377
378        Ok((estimators, estimators_features, estimators_samples))
379    }
380
381    /// Fit a single estimator with given bootstrap sample and feature indices
382    fn fit_single_estimator(
383        &self,
384        x_bootstrap: &Array2<Float>,
385        y_bootstrap: &Array1<Int>,
386        feature_indices: &[usize],
387        estimator_index: usize,
388    ) -> Result<DecisionTreeClassifier<Trained>> {
389        // Extract features for this estimator
390        let mut x_features = Array2::zeros((x_bootstrap.nrows(), feature_indices.len()));
391        for (j, &feature_idx) in feature_indices.iter().enumerate() {
392            x_features
393                .column_mut(j)
394                .assign(&x_bootstrap.column(feature_idx));
395        }
396
397        // Create and configure base estimator (decision tree)
398        let mut tree = DecisionTreeClassifier::new()
399            .criterion(SplitCriterion::Gini)
400            .min_samples_split(self.config.min_samples_split)
401            .min_samples_leaf(self.config.min_samples_leaf);
402
403        if let Some(max_depth) = self.config.max_depth {
404            tree = tree.max_depth(max_depth);
405        }
406
407        if let Some(seed) = self.config.random_state.map(|s| s + estimator_index as u64) {
408            tree = tree.random_state(Some(seed));
409        }
410
411        // Fit the tree
412        tree.fit(&x_features, y_bootstrap)
413    }
414
415    /// Determine whether to use parallel processing based on configuration
416    fn should_use_parallel(&self) -> bool {
417        match self.config.n_jobs {
418            Some(n) if n != 1 => true, // Use parallel if n_jobs is set and not 1
419            None => false,             // Don't use parallel if not specified
420            _ => false,                // n_jobs == Some(1), use sequential
421        }
422    }
423
424    /// Generate feature indices for feature bagging
425    fn get_feature_indices(&self, n_features: usize, rng: &mut StdRng) -> Vec<usize> {
426        let max_features = self.config.max_features.unwrap_or(n_features);
427        let mut feature_indices: Vec<usize> = (0..n_features).collect();
428
429        if self.config.bootstrap_features {
430            // Sample features with replacement
431            feature_indices = (0..max_features)
432                .map(|_| rng.gen_range(0..n_features))
433                .collect();
434        } else {
435            // Sample features without replacement
436            feature_indices.shuffle(rng);
437            feature_indices.truncate(max_features);
438        }
439
440        feature_indices.sort_unstable();
441        feature_indices
442    }
443
444    /// Calculate out-of-bag predictions for OOB score
445    fn calculate_oob_predictions(
446        &self,
447        x: &Array2<Float>,
448        y: &Array1<Int>,
449        estimators: &[DecisionTreeClassifier<Trained>],
450        estimators_features: &[Vec<usize>],
451        estimators_samples: &[Vec<usize>],
452    ) -> Result<Float> {
453        let n_samples = x.nrows();
454        let mut oob_predictions: Array1<Float> = Array1::zeros(n_samples);
455        let mut oob_counts: Array1<Float> = Array1::zeros(n_samples);
456
457        for (estimator_idx, (estimator, (features, samples))) in estimators
458            .iter()
459            .zip(estimators_features.iter().zip(estimators_samples.iter()))
460            .enumerate()
461        {
462            // Find out-of-bag samples (samples not used in training)
463            let mut oob_mask = vec![true; n_samples];
464            for &sample_idx in samples {
465                if sample_idx < n_samples {
466                    oob_mask[sample_idx] = false;
467                }
468            }
469
470            // Predict on out-of-bag samples
471            for (sample_idx, &is_oob) in oob_mask.iter().enumerate() {
472                if is_oob {
473                    // Extract features for this sample
474                    let x_sample = x.row(sample_idx);
475                    let x_features = Array2::from_shape_vec(
476                        (1, features.len()),
477                        features.iter().map(|&f| x_sample[f]).collect(),
478                    )
479                    .map_err(|_| {
480                        SklearsError::InvalidInput("Failed to create feature subset".to_string())
481                    })?;
482
483                    let pred = estimator.predict(&x_features)?;
484                    oob_predictions[sample_idx] += pred[0] as Float;
485                    oob_counts[sample_idx] += 1.0;
486                }
487            }
488        }
489
490        // Calculate OOB score (accuracy)
491        let mut correct = 0;
492        let mut total = 0;
493
494        for i in 0..n_samples {
495            if oob_counts[i] > 0.0 {
496                let ratio: Float = oob_predictions[i] / oob_counts[i];
497                let predicted_class: Int = ratio.round() as Int;
498                if predicted_class == y[i] {
499                    correct += 1;
500                }
501                total += 1;
502            }
503        }
504
505        if total == 0 {
506            Ok(0.0)
507        } else {
508            Ok(correct as Float / total as Float)
509        }
510    }
511}
512
513impl Fit<Array2<Float>, Array1<Int>> for BaggingClassifier<Untrained> {
514    type Fitted = BaggingClassifier<Trained>;
515
516    fn fit(self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Self::Fitted> {
517        let (n_samples, n_features) = x.dim();
518
519        if n_samples != y.len() {
520            return Err(SklearsError::ShapeMismatch {
521                expected: format!("X.shape[0] = {}", n_samples),
522                actual: format!("y.shape[0] = {}", y.len()),
523            });
524        }
525
526        if n_samples == 0 {
527            return Err(SklearsError::InvalidInput(
528                "Cannot fit bagging on empty dataset".to_string(),
529            ));
530        }
531
532        // Find unique classes
533        let mut unique_classes: Vec<Int> = y.iter().cloned().collect();
534        unique_classes.sort_unstable();
535        unique_classes.dedup();
536        let classes = Array1::from_vec(unique_classes);
537        let n_classes = classes.len();
538
539        if n_classes < 2 {
540            return Err(SklearsError::InvalidInput(
541                "Bagging requires at least 2 classes".to_string(),
542            ));
543        }
544
545        // Initialize random number generator
546        let mut rng = match self.config.random_state {
547            Some(seed) => StdRng::seed_from_u64(seed),
548            None => StdRng::seed_from_u64(42), // Use fallback seed for entropy
549        };
550
551        // Train ensemble (parallel or sequential based on configuration)
552        let (estimators, estimators_features, estimators_samples) =
553            self.train_ensemble_parallel(x, y, &mut rng, n_features)?;
554
555        // Calculate out-of-bag score if requested
556        let oob_score = if self.config.oob_score {
557            Some(self.calculate_oob_predictions(
558                x,
559                y,
560                &estimators,
561                &estimators_features,
562                &estimators_samples,
563            )?)
564        } else {
565            None
566        };
567
568        // Calculate feature importances (average over all trees)
569        let mut feature_importances = Array1::zeros(n_features);
570        for (estimator, features) in estimators.iter().zip(estimators_features.iter()) {
571            // For now, use uniform importance within selected features
572            let tree_importance = 1.0 / features.len() as Float;
573            for &feature_idx in features {
574                feature_importances[feature_idx] += tree_importance;
575            }
576        }
577
578        // Normalize feature importances
579        let total_importance = feature_importances.sum();
580        if total_importance > 0.0 {
581            feature_importances /= total_importance;
582        }
583
584        Ok(BaggingClassifier {
585            config: self.config,
586            state: PhantomData,
587            estimators_: Some(estimators),
588            estimators_features_: Some(estimators_features.to_vec()),
589            estimators_samples_: Some(estimators_samples.to_vec()),
590            oob_score_: oob_score,
591            oob_prediction_: None,
592            classes_: Some(classes),
593            n_classes_: Some(n_classes),
594            n_features_in_: Some(n_features),
595            feature_importances_: Some(feature_importances),
596        })
597    }
598}
599
600impl BaggingClassifier<Trained> {
601    /// Get the fitted base estimators
602    pub fn estimators(&self) -> &[DecisionTreeClassifier<Trained>] {
603        self.estimators_
604            .as_ref()
605            .expect("BaggingClassifier should be fitted")
606    }
607
608    /// Get the feature indices used by each estimator
609    pub fn estimators_features(&self) -> &[Vec<usize>] {
610        self.estimators_features_
611            .as_ref()
612            .expect("BaggingClassifier should be fitted")
613    }
614
615    /// Get the sample indices used by each estimator
616    pub fn estimators_samples(&self) -> &[Vec<usize>] {
617        self.estimators_samples_
618            .as_ref()
619            .expect("BaggingClassifier should be fitted")
620    }
621
622    /// Get the out-of-bag score if calculated
623    pub fn oob_score(&self) -> Option<Float> {
624        self.oob_score_
625    }
626
627    /// Get the classes
628    pub fn classes(&self) -> &Array1<Int> {
629        self.classes_
630            .as_ref()
631            .expect("BaggingClassifier should be fitted")
632    }
633
634    /// Get the number of classes
635    pub fn n_classes(&self) -> usize {
636        self.n_classes_.expect("BaggingClassifier should be fitted")
637    }
638
639    /// Get the number of input features
640    pub fn n_features_in(&self) -> usize {
641        self.n_features_in_
642            .expect("BaggingClassifier should be fitted")
643    }
644
645    /// Get feature importances
646    pub fn feature_importances(&self) -> &Array1<Float> {
647        self.feature_importances_
648            .as_ref()
649            .expect("BaggingClassifier should be fitted")
650    }
651
652    /// Calculate bootstrap confidence intervals for predictions
653    pub fn predict_with_confidence(
654        &self,
655        x: &Array2<Float>,
656    ) -> Result<(Array1<Int>, Array2<Float>)> {
657        let (n_samples, n_features) = x.dim();
658
659        if n_features != self.n_features_in() {
660            return Err(SklearsError::FeatureMismatch {
661                expected: self.n_features_in(),
662                actual: n_features,
663            });
664        }
665
666        let estimators = self.estimators();
667        let estimators_features = self.estimators_features();
668        let classes = self.classes();
669        let n_classes = self.n_classes();
670        let n_estimators = estimators.len();
671
672        let mut all_predictions = Array2::zeros((n_samples, n_estimators));
673
674        // Get predictions from all estimators
675        for (estimator_idx, (estimator, features)) in estimators
676            .iter()
677            .zip(estimators_features.iter())
678            .enumerate()
679        {
680            // Extract features for this estimator
681            let mut x_features = Array2::zeros((n_samples, features.len()));
682            for (j, &feature_idx) in features.iter().enumerate() {
683                x_features.column_mut(j).assign(&x.column(feature_idx));
684            }
685
686            let predictions = estimator.predict(&x_features)?;
687
688            // Validate prediction array size
689            if predictions.len() != n_samples {
690                return Err(SklearsError::ShapeMismatch {
691                    expected: format!("{} predictions", n_samples),
692                    actual: format!("{} predictions", predictions.len()),
693                });
694            }
695
696            for i in 0..n_samples {
697                all_predictions[[i, estimator_idx]] = predictions[i] as Float;
698            }
699        }
700
701        // Calculate final predictions and confidence intervals
702        let mut final_predictions = Array1::zeros(n_samples);
703        let mut confidence_intervals = Array2::zeros((n_samples, 2)); // [lower, upper] bounds
704
705        for i in 0..n_samples {
706            let sample_predictions = all_predictions.row(i);
707
708            // Mode for final prediction
709            let mut class_counts = vec![0; n_classes];
710            for &pred in sample_predictions {
711                let class_idx = classes.iter().position(|&c| c == pred as Int).unwrap_or(0);
712                class_counts[class_idx] += 1;
713            }
714
715            let max_class_idx = class_counts
716                .iter()
717                .enumerate()
718                .max_by(|(_, a), (_, b)| a.cmp(b))
719                .map(|(idx, _)| idx)
720                .unwrap_or(0);
721            final_predictions[i] = classes[max_class_idx];
722
723            // Bootstrap confidence interval
724            let mut sorted_predictions: Vec<Float> = sample_predictions.to_vec();
725            sorted_predictions.sort_by(|a, b| a.partial_cmp(b).unwrap());
726
727            let alpha = 1.0 - self.config.confidence_level;
728            let lower_idx = ((alpha / 2.0) * n_estimators as Float) as usize;
729            let upper_idx = ((1.0 - alpha / 2.0) * n_estimators as Float) as usize;
730
731            confidence_intervals[[i, 0]] = sorted_predictions[lower_idx.min(n_estimators - 1)];
732            confidence_intervals[[i, 1]] = sorted_predictions[upper_idx.min(n_estimators - 1)];
733        }
734
735        Ok((final_predictions, confidence_intervals))
736    }
737}
738
739impl Predict<Array2<Float>, Array1<Int>> for BaggingClassifier<Trained> {
740    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
741        let (n_samples, n_features) = x.dim();
742
743        if n_features != self.n_features_in() {
744            return Err(SklearsError::FeatureMismatch {
745                expected: self.n_features_in(),
746                actual: n_features,
747            });
748        }
749
750        let estimators = self.estimators();
751        let estimators_features = self.estimators_features();
752        let classes = self.classes();
753        let n_classes = self.n_classes();
754
755        let mut class_votes = Array2::zeros((n_samples, n_classes));
756
757        // Aggregate predictions from all estimators
758        for (estimator, features) in estimators.iter().zip(estimators_features.iter()) {
759            // Extract features for this estimator
760            let mut x_features = Array2::zeros((n_samples, features.len()));
761            for (j, &feature_idx) in features.iter().enumerate() {
762                x_features.column_mut(j).assign(&x.column(feature_idx));
763            }
764
765            let predictions = estimator.predict(&x_features)?;
766
767            // Count votes for each class
768            for (i, &pred) in predictions.iter().enumerate() {
769                if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
770                    class_votes[[i, class_idx]] += 1.0;
771                }
772            }
773        }
774
775        // Select class with most votes
776        let mut final_predictions = Array1::zeros(n_samples);
777        for i in 0..n_samples {
778            let max_idx = class_votes
779                .row(i)
780                .iter()
781                .enumerate()
782                .max_by(|(_, a): &(_, &Float), (_, b): &(_, &Float)| a.partial_cmp(b).unwrap())
783                .map(|(idx, _)| idx)
784                .unwrap_or(0);
785            final_predictions[i] = classes[max_idx];
786        }
787
788        Ok(final_predictions)
789    }
790}
791
792impl Default for BaggingClassifier<Untrained> {
793    fn default() -> Self {
794        Self::new()
795    }
796}
797
798impl<State> Estimator<State> for BaggingClassifier<State> {
799    type Config = BaggingConfig;
800    type Error = SklearsError;
801    type Float = Float;
802
803    fn config(&self) -> &Self::Config {
804        &self.config
805    }
806
807    fn validate_config(&self) -> Result<()> {
808        if self.config.n_estimators == 0 {
809            return Err(SklearsError::InvalidInput(
810                "n_estimators must be greater than 0".to_string(),
811            ));
812        }
813
814        if let Some(max_samples) = self.config.max_samples {
815            if max_samples == 0 {
816                return Err(SklearsError::InvalidInput(
817                    "max_samples must be greater than 0".to_string(),
818                ));
819            }
820        }
821
822        if let Some(max_features) = self.config.max_features {
823            if max_features == 0 {
824                return Err(SklearsError::InvalidInput(
825                    "max_features must be greater than 0".to_string(),
826                ));
827            }
828        }
829
830        if self.config.min_samples_split < 2 {
831            return Err(SklearsError::InvalidInput(
832                "min_samples_split must be at least 2".to_string(),
833            ));
834        }
835
836        if self.config.min_samples_leaf < 1 {
837            return Err(SklearsError::InvalidInput(
838                "min_samples_leaf must be at least 1".to_string(),
839            ));
840        }
841
842        if self.config.confidence_level <= 0.0 || self.config.confidence_level >= 1.0 {
843            return Err(SklearsError::InvalidInput(
844                "confidence_level must be between 0.0 and 1.0".to_string(),
845            ));
846        }
847
848        Ok(())
849    }
850
851    fn metadata(&self) -> sklears_core::traits::EstimatorMetadata {
852        sklears_core::traits::EstimatorMetadata {
853            name: "BaggingClassifier".to_string(),
854            version: env!("CARGO_PKG_VERSION").to_string(),
855            description: "Bootstrap aggregating (bagging) classifier".to_string(),
856            supports_sparse: false,
857            supports_multiclass: true,
858            supports_multilabel: false,
859            requires_positive_input: false,
860            supports_online_learning: false,
861            supports_feature_importance: true,
862            memory_complexity: sklears_core::traits::MemoryComplexity::Linear,
863            time_complexity: sklears_core::traits::TimeComplexity::LogLinear,
864        }
865    }
866}
867
868/// Enhanced Bagging regressor with OOB estimation and feature bagging
869pub struct BaggingRegressor<State = Untrained> {
870    config: BaggingConfig,
871    state: PhantomData<State>,
872    // Fitted parameters
873    estimators_: Option<Vec<DecisionTreeRegressor<Trained>>>,
874    estimators_features_: Option<Vec<Vec<usize>>>,
875    estimators_samples_: Option<Vec<Vec<usize>>>,
876    oob_score_: Option<Float>,
877    n_features_in_: Option<usize>,
878    feature_importances_: Option<Array1<Float>>,
879}
880
881impl BaggingRegressor<Untrained> {
882    /// Create a new bagging regressor
883    pub fn new() -> Self {
884        Self {
885            config: BaggingConfig::default(),
886            state: PhantomData,
887            estimators_: None,
888            estimators_features_: None,
889            estimators_samples_: None,
890            oob_score_: None,
891            n_features_in_: None,
892            feature_importances_: None,
893        }
894    }
895
896    /// Set the number of estimators
897    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
898        self.config.n_estimators = n_estimators;
899        self
900    }
901
902    /// Set the random state
903    pub fn random_state(mut self, random_state: u64) -> Self {
904        self.config.random_state = Some(random_state);
905        self
906    }
907
908    /// Set whether to calculate out-of-bag score
909    pub fn oob_score(mut self, oob_score: bool) -> Self {
910        self.config.oob_score = oob_score;
911        self
912    }
913}
914
915impl BaggingRegressor<Trained> {
916    /// Get the out-of-bag score if calculated
917    pub fn oob_score(&self) -> Option<Float> {
918        self.oob_score_
919    }
920
921    /// Get the number of input features
922    pub fn n_features_in(&self) -> usize {
923        self.n_features_in_
924            .expect("BaggingRegressor should be fitted")
925    }
926
927    /// Get feature importances
928    pub fn feature_importances(&self) -> &Array1<Float> {
929        self.feature_importances_
930            .as_ref()
931            .expect("BaggingRegressor should be fitted")
932    }
933}
934
935impl Default for BaggingRegressor<Untrained> {
936    fn default() -> Self {
937        Self::new()
938    }
939}
940
941impl<State> Estimator<State> for BaggingRegressor<State> {
942    type Config = BaggingConfig;
943    type Error = SklearsError;
944    type Float = Float;
945
946    fn config(&self) -> &Self::Config {
947        &self.config
948    }
949
950    fn validate_config(&self) -> Result<()> {
951        if self.config.n_estimators == 0 {
952            return Err(SklearsError::InvalidInput(
953                "n_estimators must be greater than 0".to_string(),
954            ));
955        }
956
957        if let Some(max_samples) = self.config.max_samples {
958            if max_samples == 0 {
959                return Err(SklearsError::InvalidInput(
960                    "max_samples must be greater than 0".to_string(),
961                ));
962            }
963        }
964
965        if let Some(max_features) = self.config.max_features {
966            if max_features == 0 {
967                return Err(SklearsError::InvalidInput(
968                    "max_features must be greater than 0".to_string(),
969                ));
970            }
971        }
972
973        if self.config.min_samples_split < 2 {
974            return Err(SklearsError::InvalidInput(
975                "min_samples_split must be at least 2".to_string(),
976            ));
977        }
978
979        if self.config.min_samples_leaf < 1 {
980            return Err(SklearsError::InvalidInput(
981                "min_samples_leaf must be at least 1".to_string(),
982            ));
983        }
984
985        if self.config.confidence_level <= 0.0 || self.config.confidence_level >= 1.0 {
986            return Err(SklearsError::InvalidInput(
987                "confidence_level must be between 0.0 and 1.0".to_string(),
988            ));
989        }
990
991        Ok(())
992    }
993
994    fn metadata(&self) -> sklears_core::traits::EstimatorMetadata {
995        sklears_core::traits::EstimatorMetadata {
996            name: "BaggingRegressor".to_string(),
997            version: env!("CARGO_PKG_VERSION").to_string(),
998            description: "Bootstrap aggregating (bagging) regressor".to_string(),
999            supports_sparse: false,
1000            supports_multiclass: false,
1001            supports_multilabel: false,
1002            requires_positive_input: false,
1003            supports_online_learning: false,
1004            supports_feature_importance: true,
1005            memory_complexity: sklears_core::traits::MemoryComplexity::Linear,
1006            time_complexity: sklears_core::traits::TimeComplexity::LogLinear,
1007        }
1008    }
1009}
1010
// Unit tests and proptest-based property tests for the bagging estimators.
// NOTE(review): no non-snake-case identifiers appear in this module, so the
// `non_snake_case` allowance below looks unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    // `HashSet` (used by the diversity property test) also arrives through this glob,
    // from the parent module's `use std::collections::HashSet`.
    use super::*;
    use scirs2_core::ndarray::array;
    // Brings `Predict` into scope for the `.predict(...)` calls below.
    use sklears_core::traits::Predict;

    // Property-based tests using proptest
    use proptest::prelude::*;

    // Builder setters must be reflected directly in the stored config.
    #[test]
    fn test_bagging_classifier_creation() {
        let classifier = BaggingClassifier::new()
            .n_estimators(20)
            .random_state(42)
            .oob_score(true);

        assert_eq!(classifier.config.n_estimators, 20);
        assert_eq!(classifier.config.random_state, Some(42));
        assert_eq!(classifier.config.oob_score, true);
    }

    // End-to-end fit/predict on a small 3-class dataset: checks the prediction
    // length and the fitted class / feature bookkeeping.
    #[test]
    fn test_bagging_classifier_fit_predict() {
        let x = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
            [7.0, 8.0],
            [8.0, 9.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2, 0, 1];

        let classifier = BaggingClassifier::new().n_estimators(5).random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 8);
        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes().len(), 3);
        assert_eq!(fitted.n_features_in(), 2);
    }

    // With bootstrap sampling and OOB scoring enabled, the OOB score must be
    // present and lie in [0, 1].
    #[test]
    fn test_bagging_classifier_with_oob() {
        let x = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
            [7.0, 8.0],
            [8.0, 9.0],
            [9.0, 10.0],
            [10.0, 11.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2, 0, 1, 2, 0];

        let classifier = BaggingClassifier::new()
            .n_estimators(10)
            .random_state(42)
            .oob_score(true)
            .bootstrap(true);

        let fitted = classifier.fit(&x, &y).unwrap();

        assert!(fitted.oob_score().is_some());
        let oob_score = fitted.oob_score().unwrap();
        assert!(oob_score >= 0.0 && oob_score <= 1.0);

        let predictions = fitted.predict(&x).unwrap();
        assert_eq!(predictions.len(), 10);
    }

    // Feature bagging: each estimator is limited to 2 of the 4 features (sampled
    // without replacement), yet importances are reported for all 4 and sum to 1.
    #[test]
    fn test_bagging_classifier_feature_bagging() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
            [5.0, 6.0, 7.0, 8.0],
            [6.0, 7.0, 8.0, 9.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2];

        let classifier = BaggingClassifier::new()
            .n_estimators(5)
            .max_features(Some(2)) // Use only 2 features per estimator
            .bootstrap_features(false)
            .random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
        assert_eq!(fitted.n_features_in(), 4);

        // Check that feature importances are calculated
        let importances = fitted.feature_importances();
        assert_eq!(importances.len(), 4);

        // Feature importances should sum to 1
        let sum: Float = importances.sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    // `predict_with_confidence` returns per-sample predictions plus an (n, 2)
    // matrix whose columns are the [lower, upper] bounds.
    #[test]
    fn test_bagging_classifier_confidence_intervals() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];
        let y = array![0, 0, 1, 1];

        let classifier = BaggingClassifier::new()
            .n_estimators(10)
            .random_state(42)
            .confidence_level(0.95);

        let fitted = classifier.fit(&x, &y).unwrap();
        let (predictions, confidence_intervals) = fitted.predict_with_confidence(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_eq!(confidence_intervals.dim(), (4, 2));

        // Check that lower bound <= upper bound
        for i in 0..4 {
            assert!(confidence_intervals[[i, 0]] <= confidence_intervals[[i, 1]]);
        }
    }

    // Regressor builder setters must be reflected in the stored config.
    #[test]
    fn test_bagging_regressor_creation() {
        let regressor = BaggingRegressor::new().n_estimators(15).random_state(123);

        assert_eq!(regressor.config.n_estimators, 15);
        assert_eq!(regressor.config.random_state, Some(123));
    }

    // Pins the default values produced by `BaggingConfig::default()`.
    #[test]
    fn test_bagging_config_default() {
        let config = BaggingConfig::default();

        assert_eq!(config.n_estimators, 10);
        assert_eq!(config.bootstrap, true);
        assert_eq!(config.bootstrap_features, false);
        assert_eq!(config.oob_score, false);
        assert_eq!(config.random_state, None);
        assert_eq!(config.min_samples_split, 2);
        assert_eq!(config.min_samples_leaf, 1);
        assert_eq!(config.confidence_level, 0.95);
    }

    // `fit` must reject empty input, mismatched X/y lengths, and a target that
    // contains only one class.
    #[test]
    fn test_bagging_classifier_invalid_input() {
        // Empty input
        let classifier = BaggingClassifier::new();
        let x = Array2::zeros((0, 2));
        let y = Array1::zeros(0);
        assert!(classifier.fit(&x, &y).is_err());

        // Mismatched dimensions
        let classifier = BaggingClassifier::new();
        let x = Array2::zeros((3, 2));
        let y = Array1::zeros(2);
        assert!(classifier.fit(&x, &y).is_err());

        // Single class
        let classifier = BaggingClassifier::new();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0, 0];
        assert!(classifier.fit(&x, &y).is_err());
    }

    // Predicting with a different number of features than seen in `fit` is an error.
    #[test]
    fn test_bagging_classifier_feature_mismatch() {
        let x_train = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y_train = array![0, 1];
        let x_test = array![[1.0, 2.0]]; // Wrong number of features

        let classifier = BaggingClassifier::new();
        let fitted = classifier.fit(&x_train, &y_train).unwrap();
        assert!(fitted.predict(&x_test).is_err());
    }

    // Property-based tests for ensemble properties

    proptest! {
        // Two fits with the same seed and settings must produce identical predictions.
        #[test]
        fn prop_bagging_deterministic_with_seed(
            n_estimators in 1usize..10,
            random_seed in 0u64..1000,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0], [7.0, 8.0], [8.0, 9.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2, 0, 1];

            // Train two identical classifiers with same seed
            let classifier1 = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .random_state(random_seed)
                .fit(&x, &y)
                .unwrap();

            let classifier2 = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .random_state(random_seed)
                .fit(&x, &y)
                .unwrap();

            let pred1 = classifier1.predict(&x).unwrap();
            let pred2 = classifier2.predict(&x).unwrap();

            // Predictions should be identical with same seed
            prop_assert_eq!(pred1, pred2);
        }

        // For any `max_features` in 1..=3 the importances are non-negative and sum to 1.
        #[test]
        fn prop_bagging_feature_importance_normalization(
            n_estimators in 1usize..10,
            max_features in 1usize..4,
        ) {
            let x = array![
                [1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0],
                [4.0, 5.0, 6.0], [5.0, 6.0, 7.0], [6.0, 7.0, 8.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .max_features(Some(max_features))
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let importances = classifier.feature_importances();
            let sum: Float = importances.sum();

            // Feature importances should sum to 1 (normalized)
            prop_assert!((sum - 1.0).abs() < 1e-10);

            // All importances should be non-negative
            for &importance in importances.iter() {
                prop_assert!(importance >= 0.0);
            }
        }

        // Intended to check that bootstrap sampling gives estimators different
        // sample sets (see NOTE(review) on the final assertion).
        #[test]
        fn prop_bagging_bootstrap_diversity(
            n_estimators in 2usize..8,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0], [7.0, 8.0], [8.0, 9.0],
                [9.0, 10.0], [10.0, 11.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2, 0, 1, 2, 0];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .bootstrap(true)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let estimators_samples = classifier.estimators_samples();

            // Each estimator should use different bootstrap samples (diversity)
            let mut unique_sample_sets = HashSet::new();
            for samples in estimators_samples {
                // Sorting makes two draws of the same multiset compare equal.
                let mut sorted_samples = samples.clone();
                sorted_samples.sort();
                unique_sample_sets.insert(sorted_samples);
            }

            // Should have some diversity in bootstrap samples
            // (not all estimators use exactly the same samples)
            // NOTE(review): as written this assertion holds whenever
            // `estimators_samples` is non-empty, so it does not actually check
            // diversity as described above. `>= 2` would match the stated
            // intent; confirm the sampler varies per estimator before tightening.
            prop_assert!(unique_sample_sets.len() >= 1);
        }

        // Every prediction is one of the fitted classes, with one prediction per input row.
        #[test]
        fn prop_bagging_prediction_stability(
            n_estimators in 3usize..10,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let predictions = classifier.predict(&x).unwrap();

            // All predictions should be valid class labels
            let classes = classifier.classes();
            for &pred in predictions.iter() {
                prop_assert!(classes.iter().any(|&c| c == pred));
            }

            // Number of predictions should match input samples
            prop_assert_eq!(predictions.len(), x.nrows());
        }

        // When an OOB score is produced it must lie in [0, 1].
        #[test]
        fn prop_bagging_oob_score_bounds(
            n_estimators in 5usize..15,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0], [7.0, 8.0], [8.0, 9.0],
                [9.0, 10.0], [10.0, 11.0], [11.0, 12.0], [12.0, 13.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 1, 2];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .oob_score(true)
                .bootstrap(true)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            if let Some(oob_score) = classifier.oob_score() {
                // OOB score should be between 0 and 1 (accuracy)
                prop_assert!(oob_score >= 0.0 && oob_score <= 1.0);
            }
        }

        // For confidence levels in [0.7, 0.99), intervals are ordered and finite.
        #[test]
        fn prop_bagging_confidence_intervals_bounds(
            n_estimators in 3usize..8,
            confidence_level in 0.7..0.99,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
            ];
            let y = array![0, 0, 1, 1];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .confidence_level(confidence_level)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let (predictions, confidence_intervals) = classifier.predict_with_confidence(&x).unwrap();

            // Check confidence interval properties
            for i in 0..predictions.len() {
                let lower = confidence_intervals[[i, 0]];
                let upper = confidence_intervals[[i, 1]];

                // Lower bound should be <= upper bound
                prop_assert!(lower <= upper);

                // Confidence intervals should be reasonable
                prop_assert!(lower.is_finite() && upper.is_finite());
            }
        }
    }
}