use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::rand_prelude::SliceRandom;
use scirs2_core::random::prelude::*;
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::Predict;
use sklears_core::traits::{Estimator, Fit, Trained, Untrained};
use sklears_core::types::{Float, Int};

use crate::adaboost::{DecisionTreeClassifier, DecisionTreeRegressor, SplitCriterion};

#[allow(unused_imports)]
use std::collections::HashSet;
use std::marker::PhantomData;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

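/// Configuration shared by [`BaggingClassifier`] and [`BaggingRegressor`].
///
/// Controls ensemble size, the bootstrap strategy for samples and features,
/// base-tree hyperparameters, and the confidence level used by
/// `predict_with_confidence`.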
#[derive(Debug, Clone)]
pub struct BaggingConfig {
    /// Number of base estimators in the ensemble.
    pub n_estimators: usize,
    /// Number of samples drawn per estimator (defaults to all samples).
    pub max_samples: Option<usize>,
    /// Number of features drawn per estimator (defaults to all features).
    pub max_features: Option<usize>,
    /// Whether samples are drawn with replacement.
    pub bootstrap: bool,
    /// Whether features are drawn with replacement.
    pub bootstrap_features: bool,
    /// Seed for reproducible sampling.
    pub random_state: Option<u64>,
    /// Whether to estimate generalization accuracy on out-of-bag samples.
    pub oob_score: bool,
    /// Parallelism: any value other than `None` or `Some(1)` requests
    /// parallel training (effective only with the `parallel` feature).
    pub n_jobs: Option<i32>,
    /// Maximum depth of each base tree.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Confidence level for prediction intervals (must lie strictly in (0, 1)).
    pub confidence_level: Float,
    /// Extra-trees mode: train each estimator on the full dataset instead of
    /// a bootstrap sample.
    pub extra_randomized: bool,
}

impl Default for BaggingConfig {
    fn default() -> Self {
        Self {
            n_estimators: 10,
            max_samples: None,
            max_features: None,
            bootstrap: true,
            bootstrap_features: false,
            random_state: None,
            oob_score: false,
            n_jobs: None,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            confidence_level: 0.95,
            extra_randomized: false,
        }
    }
}

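/// A bagging classifier: an ensemble of decision trees, each trained on a
/// bootstrap sample of the data (and optionally a random subset of features),
/// with predictions combined by majority vote.
///
/// # Example
///
/// A minimal usage sketch (the builder methods are defined below; `fit` and
/// `predict` come from the `sklears_core` traits):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 0, 1, 1];
///
/// let fitted = BaggingClassifier::new()
///     .n_estimators(10)
///     .random_state(42)
///     .fit(&x, &y)?;
/// let predictions = fitted.predict(&x)?;
/// ```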
pub struct BaggingClassifier<State = Untrained> {
    config: BaggingConfig,
    state: PhantomData<State>,
    estimators_: Option<Vec<DecisionTreeClassifier<Trained>>>,
    estimators_features_: Option<Vec<Vec<usize>>>,
    estimators_samples_: Option<Vec<Vec<usize>>>,
    oob_score_: Option<Float>,
    oob_prediction_: Option<Array1<Float>>,
    classes_: Option<Array1<Int>>,
    n_classes_: Option<usize>,
    n_features_in_: Option<usize>,
    feature_importances_: Option<Array1<Float>>,
}

impl BaggingClassifier<Untrained> {
    /// Creates a new, unfitted bagging classifier with the default configuration.
    pub fn new() -> Self {
        Self {
            config: BaggingConfig::default(),
            state: PhantomData,
            estimators_: None,
            estimators_features_: None,
            estimators_samples_: None,
            oob_score_: None,
            oob_prediction_: None,
            classes_: None,
            n_classes_: None,
            n_features_in_: None,
            feature_importances_: None,
        }
    }

    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    pub fn max_samples(mut self, max_samples: Option<usize>) -> Self {
        self.config.max_samples = max_samples;
        self
    }

    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
        self.config.max_features = max_features;
        self
    }

    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.config.bootstrap = bootstrap;
        self
    }

    pub fn bootstrap_features(mut self, bootstrap_features: bool) -> Self {
        self.config.bootstrap_features = bootstrap_features;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.config.random_state = Some(random_state);
        self
    }

    pub fn oob_score(mut self, oob_score: bool) -> Self {
        self.config.oob_score = oob_score;
        self
    }

    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.config.max_depth = max_depth;
        self
    }

    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.config.min_samples_split = min_samples_split;
        self
    }

    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.config.min_samples_leaf = min_samples_leaf;
        self
    }

    pub fn confidence_level(mut self, confidence_level: Float) -> Self {
        self.config.confidence_level = confidence_level;
        self
    }

    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
        self.config.n_jobs = n_jobs;
        self
    }

    /// Enables parallel training using all available cores.
    pub fn parallel(mut self) -> Self {
        self.config.n_jobs = Some(-1);
        self
    }

    pub fn extra_randomized(mut self, extra_randomized: bool) -> Self {
        self.config.extra_randomized = extra_randomized;
        self
    }

    /// Configures an extremely randomized ensemble: disables bootstrap
    /// sampling so every estimator sees the full training set.
    pub fn extremely_randomized(mut self) -> Self {
        self.config.extra_randomized = true;
        self.config.bootstrap = false;
        self
    }

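    /// Draws the training set for one estimator: a bootstrap sample with
    /// replacement, or a random subset without replacement when `bootstrap`
    /// is off. If the draw happens to contain a single class while the data
    /// has several, one index is swapped in from another class so the base
    /// tree can still learn a split. Extra-randomized mode returns the full
    /// dataset unchanged.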
    fn bootstrap_sample(
        &self,
        x: &Array2<Float>,
        y: &Array1<Int>,
        rng: &mut StdRng,
    ) -> Result<(Array2<Float>, Array1<Int>, Vec<usize>)> {
        let n_samples = x.nrows();

        if self.config.extra_randomized {
            let sample_indices: Vec<usize> = (0..n_samples).collect();
            return Ok((x.clone(), y.clone(), sample_indices));
        }

        let max_samples = self.config.max_samples.unwrap_or(n_samples);

        let mut class_indices: std::collections::HashMap<Int, Vec<usize>> =
            std::collections::HashMap::new();
        for (idx, &class) in y.iter().enumerate() {
            class_indices.entry(class).or_default().push(idx);
        }

        let mut sample_indices: Vec<usize> = if self.config.bootstrap {
            // Sample with replacement.
            (0..max_samples)
                .map(|_| rng.gen_range(0..n_samples))
                .collect()
        } else {
            // Sample without replacement.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            indices.shuffle(rng);
            indices.truncate(max_samples);
            indices
        };

        let mut sampled_classes = std::collections::HashSet::new();
        for &idx in &sample_indices {
            sampled_classes.insert(y[idx]);
        }

        // Guard against degenerate draws: if only one class was sampled,
        // replace the last index with a sample from another class.
        if sampled_classes.len() < 2 && class_indices.len() >= 2 {
            let mut other_classes: Vec<Int> = class_indices
                .keys()
                .filter(|&&c| !sampled_classes.contains(&c))
                .cloned()
                .collect();

            other_classes.sort();

            if !other_classes.is_empty() {
                let other_class = other_classes[0];
                if let Some(other_indices) = class_indices.get(&other_class) {
                    if !other_indices.is_empty() {
                        let replacement_idx = other_indices[0];
                        if let Some(last) = sample_indices.last_mut() {
                            *last = replacement_idx;
                        }
                    }
                }
            }
        }

        let mut x_bootstrap = Array2::zeros((max_samples, x.ncols()));
        let mut y_bootstrap = Array1::zeros(max_samples);

        for (i, &idx) in sample_indices.iter().enumerate() {
            x_bootstrap.row_mut(i).assign(&x.row(idx));
            y_bootstrap[i] = y[idx];
        }

        Ok((x_bootstrap, y_bootstrap, sample_indices))
    }

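    /// Trains the full ensemble. Each estimator gets its own RNG seeded from
    /// `random_state + i`, so results are reproducible regardless of whether
    /// training runs in parallel (rayon, behind the `parallel` feature) or
    /// sequentially.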
    fn train_ensemble_parallel(
        &self,
        x: &Array2<Float>,
        y: &Array1<Int>,
        _rng: &mut StdRng,
        n_features: usize,
    ) -> Result<(
        Vec<DecisionTreeClassifier<Trained>>,
        Vec<Vec<usize>>,
        Vec<Vec<usize>>,
    )> {
        // Draw all bootstrap samples up front with per-estimator RNGs so the
        // sampling does not depend on thread scheduling.
        let mut bootstrap_data = Vec::new();
        for i in 0..self.config.n_estimators {
            let mut local_rng =
                StdRng::seed_from_u64(self.config.random_state.unwrap_or(42) + i as u64);

            let (x_bootstrap, y_bootstrap, sample_indices) =
                self.bootstrap_sample(x, y, &mut local_rng)?;
            let feature_indices = self.get_feature_indices(n_features, &mut local_rng);

            bootstrap_data.push((x_bootstrap, y_bootstrap, sample_indices, feature_indices));
        }

        let use_parallel = self.should_use_parallel();

        if use_parallel {
            #[cfg(feature = "parallel")]
            {
                let results: Result<Vec<_>> = bootstrap_data
                    .into_par_iter()
                    .enumerate()
                    .map(
                        |(i, (x_bootstrap, y_bootstrap, sample_indices, feature_indices))| {
                            self.fit_single_estimator(
                                &x_bootstrap,
                                &y_bootstrap,
                                &feature_indices,
                                i,
                            )
                            .map(|estimator| (estimator, feature_indices, sample_indices))
                        },
                    )
                    .collect();

                let fitted_data = results?;
                let (estimators, estimators_features, estimators_samples) =
                    fitted_data.into_iter().fold(
                        (Vec::new(), Vec::new(), Vec::new()),
                        |(mut e, mut ef, mut es), (estimator, features, samples)| {
                            e.push(estimator);
                            ef.push(features);
                            es.push(samples);
                            (e, ef, es)
                        },
                    );

                Ok((estimators, estimators_features, estimators_samples))
            }
            #[cfg(not(feature = "parallel"))]
            {
                self.train_ensemble_sequential(bootstrap_data)
            }
        } else {
            self.train_ensemble_sequential(bootstrap_data)
        }
    }

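    /// Sequential fallback used when parallel training is disabled or the
    /// `parallel` feature is not compiled in.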
    fn train_ensemble_sequential(
        &self,
        bootstrap_data: Vec<(Array2<Float>, Array1<Int>, Vec<usize>, Vec<usize>)>,
    ) -> Result<(
        Vec<DecisionTreeClassifier<Trained>>,
        Vec<Vec<usize>>,
        Vec<Vec<usize>>,
    )> {
        let mut estimators = Vec::new();
        let mut estimators_features = Vec::new();
        let mut estimators_samples = Vec::new();

        for (i, (x_bootstrap, y_bootstrap, sample_indices, feature_indices)) in
            bootstrap_data.into_iter().enumerate()
        {
            let fitted_tree =
                self.fit_single_estimator(&x_bootstrap, &y_bootstrap, &feature_indices, i)?;

            estimators.push(fitted_tree);
            estimators_features.push(feature_indices);
            estimators_samples.push(sample_indices);
        }

        Ok((estimators, estimators_features, estimators_samples))
    }

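    /// Fits one decision tree on the given bootstrap sample, restricted to
    /// the selected feature columns. The estimator index offsets the seed so
    /// each tree is randomized differently.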
    fn fit_single_estimator(
        &self,
        x_bootstrap: &Array2<Float>,
        y_bootstrap: &Array1<Int>,
        feature_indices: &[usize],
        estimator_index: usize,
    ) -> Result<DecisionTreeClassifier<Trained>> {
        // Project the bootstrap sample onto the selected feature subset.
        let mut x_features = Array2::zeros((x_bootstrap.nrows(), feature_indices.len()));
        for (j, &feature_idx) in feature_indices.iter().enumerate() {
            x_features
                .column_mut(j)
                .assign(&x_bootstrap.column(feature_idx));
        }

        let mut tree = DecisionTreeClassifier::new()
            .criterion(SplitCriterion::Gini)
            .min_samples_split(self.config.min_samples_split)
            .min_samples_leaf(self.config.min_samples_leaf);

        if let Some(max_depth) = self.config.max_depth {
            tree = tree.max_depth(max_depth);
        }

        if let Some(seed) = self.config.random_state.map(|s| s + estimator_index as u64) {
            tree = tree.random_state(Some(seed));
        }

        tree.fit(&x_features, y_bootstrap)
    }

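    /// Parallel training is requested for any `n_jobs` other than `None` or
    /// `Some(1)`.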
    fn should_use_parallel(&self) -> bool {
        match self.config.n_jobs {
            Some(n) if n != 1 => true,
            None => false,
            _ => false,
        }
    }

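    /// Selects the feature subset for one estimator: drawn with replacement
    /// when `bootstrap_features` is on, otherwise a shuffled subset without
    /// replacement. Indices are sorted so column projection is deterministic.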
    fn get_feature_indices(&self, n_features: usize, rng: &mut StdRng) -> Vec<usize> {
        let max_features = self.config.max_features.unwrap_or(n_features);
        let mut feature_indices: Vec<usize> = (0..n_features).collect();

        if self.config.bootstrap_features {
            feature_indices = (0..max_features)
                .map(|_| rng.gen_range(0..n_features))
                .collect();
        } else {
            feature_indices.shuffle(rng);
            feature_indices.truncate(max_features);
        }

        feature_indices.sort_unstable();
        feature_indices
    }

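    /// Computes the out-of-bag (OOB) accuracy: each sample is scored only by
    /// the estimators that did not see it during training. Per-estimator
    /// label predictions are averaged and rounded, which is exact for binary
    /// labels and an approximation of majority vote for multiclass problems.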
    fn calculate_oob_predictions(
        &self,
        x: &Array2<Float>,
        y: &Array1<Int>,
        estimators: &[DecisionTreeClassifier<Trained>],
        estimators_features: &[Vec<usize>],
        estimators_samples: &[Vec<usize>],
    ) -> Result<Float> {
        let n_samples = x.nrows();
        let mut oob_predictions: Array1<Float> = Array1::zeros(n_samples);
        let mut oob_counts: Array1<Float> = Array1::zeros(n_samples);

        for (estimator, (features, samples)) in estimators
            .iter()
            .zip(estimators_features.iter().zip(estimators_samples.iter()))
        {
            // Mark every sample this estimator trained on as in-bag.
            let mut oob_mask = vec![true; n_samples];
            for &sample_idx in samples {
                if sample_idx < n_samples {
                    oob_mask[sample_idx] = false;
                }
            }

            for (sample_idx, &is_oob) in oob_mask.iter().enumerate() {
                if is_oob {
                    let x_sample = x.row(sample_idx);
                    let x_features = Array2::from_shape_vec(
                        (1, features.len()),
                        features.iter().map(|&f| x_sample[f]).collect(),
                    )
                    .map_err(|_| {
                        SklearsError::InvalidInput("Failed to create feature subset".to_string())
                    })?;

                    let pred = estimator.predict(&x_features)?;
                    oob_predictions[sample_idx] += pred[0] as Float;
                    oob_counts[sample_idx] += 1.0;
                }
            }
        }

        let mut correct = 0;
        let mut total = 0;

        for i in 0..n_samples {
            if oob_counts[i] > 0.0 {
                let ratio: Float = oob_predictions[i] / oob_counts[i];
                let predicted_class: Int = ratio.round() as Int;
                if predicted_class == y[i] {
                    correct += 1;
                }
                total += 1;
            }
        }

        if total == 0 {
            Ok(0.0)
        } else {
            Ok(correct as Float / total as Float)
        }
    }
}

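// Fitting validates the inputs, trains the ensemble, optionally computes the
// OOB score, and derives feature importances from how often each feature was
// selected (a usage-frequency proxy, not impurity-based importance).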
impl Fit<Array2<Float>, Array1<Int>> for BaggingClassifier<Untrained> {
    type Fitted = BaggingClassifier<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Int>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("X.shape[0] = {}", n_samples),
                actual: format!("y.shape[0] = {}", y.len()),
            });
        }

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit bagging on empty dataset".to_string(),
            ));
        }

        let mut unique_classes: Vec<Int> = y.iter().cloned().collect();
        unique_classes.sort_unstable();
        unique_classes.dedup();
        let classes = Array1::from_vec(unique_classes);
        let n_classes = classes.len();

        if n_classes < 2 {
            return Err(SklearsError::InvalidInput(
                "Bagging requires at least 2 classes".to_string(),
            ));
        }

        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let (estimators, estimators_features, estimators_samples) =
            self.train_ensemble_parallel(x, y, &mut rng, n_features)?;

        let oob_score = if self.config.oob_score {
            Some(self.calculate_oob_predictions(
                x,
                y,
                &estimators,
                &estimators_features,
                &estimators_samples,
            )?)
        } else {
            None
        };

        // Each estimator spreads a unit of importance weight evenly over the
        // features it was trained on.
        let mut feature_importances = Array1::zeros(n_features);
        for features in &estimators_features {
            let tree_importance = 1.0 / features.len() as Float;
            for &feature_idx in features {
                feature_importances[feature_idx] += tree_importance;
            }
        }

        // Normalize so importances sum to 1.
        let total_importance = feature_importances.sum();
        if total_importance > 0.0 {
            feature_importances /= total_importance;
        }

        Ok(BaggingClassifier {
            config: self.config,
            state: PhantomData,
            estimators_: Some(estimators),
            estimators_features_: Some(estimators_features),
            estimators_samples_: Some(estimators_samples),
            oob_score_: oob_score,
            oob_prediction_: None,
            classes_: Some(classes),
            n_classes_: Some(n_classes),
            n_features_in_: Some(n_features),
            feature_importances_: Some(feature_importances),
        })
    }
}

impl BaggingClassifier<Trained> {
    pub fn estimators(&self) -> &[DecisionTreeClassifier<Trained>] {
        self.estimators_
            .as_ref()
            .expect("BaggingClassifier should be fitted")
    }

    pub fn estimators_features(&self) -> &[Vec<usize>] {
        self.estimators_features_
            .as_ref()
            .expect("BaggingClassifier should be fitted")
    }

    pub fn estimators_samples(&self) -> &[Vec<usize>] {
        self.estimators_samples_
            .as_ref()
            .expect("BaggingClassifier should be fitted")
    }

    pub fn oob_score(&self) -> Option<Float> {
        self.oob_score_
    }

    pub fn classes(&self) -> &Array1<Int> {
        self.classes_
            .as_ref()
            .expect("BaggingClassifier should be fitted")
    }

    pub fn n_classes(&self) -> usize {
        self.n_classes_.expect("BaggingClassifier should be fitted")
    }

    pub fn n_features_in(&self) -> usize {
        self.n_features_in_
            .expect("BaggingClassifier should be fitted")
    }

    pub fn feature_importances(&self) -> &Array1<Float> {
        self.feature_importances_
            .as_ref()
            .expect("BaggingClassifier should be fitted")
    }

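    /// Predicts class labels together with per-sample confidence intervals,
    /// taken as empirical quantiles of the individual estimators' predicted
    /// labels at the configured `confidence_level`.
    ///
    /// A minimal usage sketch (assumes a fitted classifier as in the struct
    /// example above):
    ///
    /// ```ignore
    /// let (labels, intervals) = fitted.predict_with_confidence(&x)?;
    /// assert_eq!(intervals.dim(), (x.nrows(), 2)); // [lower, upper] per sample
    /// ```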
    pub fn predict_with_confidence(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array1<Int>, Array2<Float>)> {
        let (n_samples, n_features) = x.dim();

        if n_features != self.n_features_in() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features_in(),
                actual: n_features,
            });
        }

        let estimators = self.estimators();
        let estimators_features = self.estimators_features();
        let classes = self.classes();
        let n_classes = self.n_classes();
        let n_estimators = estimators.len();

        // Collect every estimator's per-sample label prediction.
        let mut all_predictions = Array2::zeros((n_samples, n_estimators));

        for (estimator_idx, (estimator, features)) in estimators
            .iter()
            .zip(estimators_features.iter())
            .enumerate()
        {
            let mut x_features = Array2::zeros((n_samples, features.len()));
            for (j, &feature_idx) in features.iter().enumerate() {
                x_features.column_mut(j).assign(&x.column(feature_idx));
            }

            let predictions = estimator.predict(&x_features)?;

            if predictions.len() != n_samples {
                return Err(SklearsError::ShapeMismatch {
                    expected: format!("{} predictions", n_samples),
                    actual: format!("{} predictions", predictions.len()),
                });
            }

            for i in 0..n_samples {
                all_predictions[[i, estimator_idx]] = predictions[i] as Float;
            }
        }

        let mut final_predictions = Array1::zeros(n_samples);
        let mut confidence_intervals = Array2::zeros((n_samples, 2));

        for i in 0..n_samples {
            let sample_predictions = all_predictions.row(i);

            // Majority vote for the point prediction.
            let mut class_counts = vec![0; n_classes];
            for &pred in sample_predictions {
                let class_idx = classes.iter().position(|&c| c == pred as Int).unwrap_or(0);
                class_counts[class_idx] += 1;
            }

            let max_class_idx = class_counts
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.cmp(b))
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            final_predictions[i] = classes[max_class_idx];

            // Empirical quantiles of the sorted predictions form the interval.
            let mut sorted_predictions: Vec<Float> = sample_predictions.to_vec();
            sorted_predictions.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let alpha = 1.0 - self.config.confidence_level;
            let lower_idx = ((alpha / 2.0) * n_estimators as Float) as usize;
            let upper_idx = ((1.0 - alpha / 2.0) * n_estimators as Float) as usize;

            confidence_intervals[[i, 0]] = sorted_predictions[lower_idx.min(n_estimators - 1)];
            confidence_intervals[[i, 1]] = sorted_predictions[upper_idx.min(n_estimators - 1)];
        }

        Ok((final_predictions, confidence_intervals))
    }
}

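// Prediction aggregates the ensemble by majority vote: each estimator votes
// for a class on its own feature subset, and the most-voted class wins.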
impl Predict<Array2<Float>, Array1<Int>> for BaggingClassifier<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
        let (n_samples, n_features) = x.dim();

        if n_features != self.n_features_in() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features_in(),
                actual: n_features,
            });
        }

        let estimators = self.estimators();
        let estimators_features = self.estimators_features();
        let classes = self.classes();
        let n_classes = self.n_classes();

        let mut class_votes = Array2::zeros((n_samples, n_classes));

        for (estimator, features) in estimators.iter().zip(estimators_features.iter()) {
            let mut x_features = Array2::zeros((n_samples, features.len()));
            for (j, &feature_idx) in features.iter().enumerate() {
                x_features.column_mut(j).assign(&x.column(feature_idx));
            }

            let predictions = estimator.predict(&x_features)?;

            for (i, &pred) in predictions.iter().enumerate() {
                if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                    class_votes[[i, class_idx]] += 1.0;
                }
            }
        }

        let mut final_predictions = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let max_idx = class_votes
                .row(i)
                .iter()
                .enumerate()
                .max_by(|(_, a): &(_, &Float), (_, b): &(_, &Float)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            final_predictions[i] = classes[max_idx];
        }

        Ok(final_predictions)
    }
}

impl Default for BaggingClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> Estimator<State> for BaggingClassifier<State> {
    type Config = BaggingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn validate_config(&self) -> Result<()> {
        if self.config.n_estimators == 0 {
            return Err(SklearsError::InvalidInput(
                "n_estimators must be greater than 0".to_string(),
            ));
        }

        if let Some(max_samples) = self.config.max_samples {
            if max_samples == 0 {
                return Err(SklearsError::InvalidInput(
                    "max_samples must be greater than 0".to_string(),
                ));
            }
        }

        if let Some(max_features) = self.config.max_features {
            if max_features == 0 {
                return Err(SklearsError::InvalidInput(
                    "max_features must be greater than 0".to_string(),
                ));
            }
        }

        if self.config.min_samples_split < 2 {
            return Err(SklearsError::InvalidInput(
                "min_samples_split must be at least 2".to_string(),
            ));
        }

        if self.config.min_samples_leaf < 1 {
            return Err(SklearsError::InvalidInput(
                "min_samples_leaf must be at least 1".to_string(),
            ));
        }

        if self.config.confidence_level <= 0.0 || self.config.confidence_level >= 1.0 {
            return Err(SklearsError::InvalidInput(
                "confidence_level must be strictly between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(())
    }

    fn metadata(&self) -> sklears_core::traits::EstimatorMetadata {
        sklears_core::traits::EstimatorMetadata {
            name: "BaggingClassifier".to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
            description: "Bootstrap aggregating (bagging) classifier".to_string(),
            supports_sparse: false,
            supports_multiclass: true,
            supports_multilabel: false,
            requires_positive_input: false,
            supports_online_learning: false,
            supports_feature_importance: true,
            memory_complexity: sklears_core::traits::MemoryComplexity::Linear,
            time_complexity: sklears_core::traits::TimeComplexity::LogLinear,
        }
    }
}

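/// A bagging regressor: the regression counterpart of [`BaggingClassifier`],
/// built from [`DecisionTreeRegressor`] base estimators.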
pub struct BaggingRegressor<State = Untrained> {
    config: BaggingConfig,
    state: PhantomData<State>,
    estimators_: Option<Vec<DecisionTreeRegressor<Trained>>>,
    estimators_features_: Option<Vec<Vec<usize>>>,
    estimators_samples_: Option<Vec<Vec<usize>>>,
    oob_score_: Option<Float>,
    n_features_in_: Option<usize>,
    feature_importances_: Option<Array1<Float>>,
}

impl BaggingRegressor<Untrained> {
    pub fn new() -> Self {
        Self {
            config: BaggingConfig::default(),
            state: PhantomData,
            estimators_: None,
            estimators_features_: None,
            estimators_samples_: None,
            oob_score_: None,
            n_features_in_: None,
            feature_importances_: None,
        }
    }

    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.config.random_state = Some(random_state);
        self
    }

    pub fn oob_score(mut self, oob_score: bool) -> Self {
        self.config.oob_score = oob_score;
        self
    }
}

impl BaggingRegressor<Trained> {
    pub fn oob_score(&self) -> Option<Float> {
        self.oob_score_
    }

    pub fn n_features_in(&self) -> usize {
        self.n_features_in_
            .expect("BaggingRegressor should be fitted")
    }

    pub fn feature_importances(&self) -> &Array1<Float> {
        self.feature_importances_
            .as_ref()
            .expect("BaggingRegressor should be fitted")
    }
}

impl Default for BaggingRegressor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> Estimator<State> for BaggingRegressor<State> {
    type Config = BaggingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn validate_config(&self) -> Result<()> {
        if self.config.n_estimators == 0 {
            return Err(SklearsError::InvalidInput(
                "n_estimators must be greater than 0".to_string(),
            ));
        }

        if let Some(max_samples) = self.config.max_samples {
            if max_samples == 0 {
                return Err(SklearsError::InvalidInput(
                    "max_samples must be greater than 0".to_string(),
                ));
            }
        }

        if let Some(max_features) = self.config.max_features {
            if max_features == 0 {
                return Err(SklearsError::InvalidInput(
                    "max_features must be greater than 0".to_string(),
                ));
            }
        }

        if self.config.min_samples_split < 2 {
            return Err(SklearsError::InvalidInput(
                "min_samples_split must be at least 2".to_string(),
            ));
        }

        if self.config.min_samples_leaf < 1 {
            return Err(SklearsError::InvalidInput(
                "min_samples_leaf must be at least 1".to_string(),
            ));
        }

        if self.config.confidence_level <= 0.0 || self.config.confidence_level >= 1.0 {
            return Err(SklearsError::InvalidInput(
                "confidence_level must be strictly between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(())
    }

    fn metadata(&self) -> sklears_core::traits::EstimatorMetadata {
        sklears_core::traits::EstimatorMetadata {
            name: "BaggingRegressor".to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
            description: "Bootstrap aggregating (bagging) regressor".to_string(),
            supports_sparse: false,
            supports_multiclass: false,
            supports_multilabel: false,
            requires_positive_input: false,
            supports_online_learning: false,
            supports_feature_importance: true,
            memory_complexity: sklears_core::traits::MemoryComplexity::Linear,
            time_complexity: sklears_core::traits::TimeComplexity::LogLinear,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::Predict;

    use proptest::prelude::*;

    #[test]
    fn test_bagging_classifier_creation() {
        let classifier = BaggingClassifier::new()
            .n_estimators(20)
            .random_state(42)
            .oob_score(true);

        assert_eq!(classifier.config.n_estimators, 20);
        assert_eq!(classifier.config.random_state, Some(42));
        assert!(classifier.config.oob_score);
    }

    #[test]
    fn test_bagging_classifier_fit_predict() {
        let x = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
            [7.0, 8.0],
            [8.0, 9.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2, 0, 1];

        let classifier = BaggingClassifier::new().n_estimators(5).random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 8);
        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes().len(), 3);
        assert_eq!(fitted.n_features_in(), 2);
    }

    #[test]
    fn test_bagging_classifier_with_oob() {
        let x = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
            [7.0, 8.0],
            [8.0, 9.0],
            [9.0, 10.0],
            [10.0, 11.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2, 0, 1, 2, 0];

        let classifier = BaggingClassifier::new()
            .n_estimators(10)
            .random_state(42)
            .oob_score(true)
            .bootstrap(true);

        let fitted = classifier.fit(&x, &y).unwrap();

        assert!(fitted.oob_score().is_some());
        let oob_score = fitted.oob_score().unwrap();
        assert!(oob_score >= 0.0 && oob_score <= 1.0);

        let predictions = fitted.predict(&x).unwrap();
        assert_eq!(predictions.len(), 10);
    }

    #[test]
    fn test_bagging_classifier_feature_bagging() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
            [5.0, 6.0, 7.0, 8.0],
            [6.0, 7.0, 8.0, 9.0],
        ];
        let y = array![0, 0, 1, 1, 2, 2];

        let classifier = BaggingClassifier::new()
            .n_estimators(5)
            .max_features(Some(2))
            .bootstrap_features(false)
            .random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
        assert_eq!(fitted.n_features_in(), 4);

        let importances = fitted.feature_importances();
        assert_eq!(importances.len(), 4);

        let sum: Float = importances.sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_bagging_classifier_confidence_intervals() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 0, 1, 1];

        let classifier = BaggingClassifier::new()
            .n_estimators(10)
            .random_state(42)
            .confidence_level(0.95);

        let fitted = classifier.fit(&x, &y).unwrap();
        let (predictions, confidence_intervals) = fitted.predict_with_confidence(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_eq!(confidence_intervals.dim(), (4, 2));

        for i in 0..4 {
            assert!(confidence_intervals[[i, 0]] <= confidence_intervals[[i, 1]]);
        }
    }

    #[test]
    fn test_bagging_regressor_creation() {
        let regressor = BaggingRegressor::new().n_estimators(15).random_state(123);

        assert_eq!(regressor.config.n_estimators, 15);
        assert_eq!(regressor.config.random_state, Some(123));
    }

    #[test]
    fn test_bagging_config_default() {
        let config = BaggingConfig::default();

        assert_eq!(config.n_estimators, 10);
        assert!(config.bootstrap);
        assert!(!config.bootstrap_features);
        assert!(!config.oob_score);
        assert_eq!(config.random_state, None);
        assert_eq!(config.min_samples_split, 2);
        assert_eq!(config.min_samples_leaf, 1);
        assert_eq!(config.confidence_level, 0.95);
    }

    #[test]
    fn test_bagging_classifier_invalid_input() {
        // Empty dataset.
        let classifier = BaggingClassifier::new();
        let x = Array2::zeros((0, 2));
        let y = Array1::zeros(0);
        assert!(classifier.fit(&x, &y).is_err());

        // Mismatched number of samples between X and y.
        let classifier = BaggingClassifier::new();
        let x = Array2::zeros((3, 2));
        let y = Array1::zeros(2);
        assert!(classifier.fit(&x, &y).is_err());

        // A single class is not enough for classification.
        let classifier = BaggingClassifier::new();
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0, 0];
        assert!(classifier.fit(&x, &y).is_err());
    }

    #[test]
    fn test_bagging_classifier_feature_mismatch() {
        let x_train = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y_train = array![0, 1];
        let x_test = array![[1.0, 2.0]];

        let classifier = BaggingClassifier::new();
        let fitted = classifier.fit(&x_train, &y_train).unwrap();
        assert!(fitted.predict(&x_test).is_err());
    }

    proptest! {
        #[test]
        fn prop_bagging_deterministic_with_seed(
            n_estimators in 1usize..10,
            random_seed in 0u64..1000,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0], [7.0, 8.0], [8.0, 9.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2, 0, 1];

            let classifier1 = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .random_state(random_seed)
                .fit(&x, &y)
                .unwrap();

            let classifier2 = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .random_state(random_seed)
                .fit(&x, &y)
                .unwrap();

            let pred1 = classifier1.predict(&x).unwrap();
            let pred2 = classifier2.predict(&x).unwrap();

            prop_assert_eq!(pred1, pred2);
        }

        #[test]
        fn prop_bagging_feature_importance_normalization(
            n_estimators in 1usize..10,
            max_features in 1usize..4,
        ) {
            let x = array![
                [1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0],
                [4.0, 5.0, 6.0], [5.0, 6.0, 7.0], [6.0, 7.0, 8.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .max_features(Some(max_features))
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let importances = classifier.feature_importances();
            let sum: Float = importances.sum();

            prop_assert!((sum - 1.0).abs() < 1e-10);

            for &importance in importances.iter() {
                prop_assert!(importance >= 0.0);
            }
        }

        #[test]
        fn prop_bagging_bootstrap_diversity(
            n_estimators in 2usize..8,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0], [7.0, 8.0], [8.0, 9.0],
                [9.0, 10.0], [10.0, 11.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2, 0, 1, 2, 0];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .bootstrap(true)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let estimators_samples = classifier.estimators_samples();

            let mut unique_sample_sets = HashSet::new();
            for samples in estimators_samples {
                let mut sorted_samples = samples.clone();
                sorted_samples.sort();
                unique_sample_sets.insert(sorted_samples);
            }

            prop_assert!(!unique_sample_sets.is_empty());
        }

        #[test]
        fn prop_bagging_prediction_stability(
            n_estimators in 3usize..10,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let predictions = classifier.predict(&x).unwrap();

            // Every prediction must be one of the training classes.
            let classes = classifier.classes();
            for &pred in predictions.iter() {
                prop_assert!(classes.iter().any(|&c| c == pred));
            }

            prop_assert_eq!(predictions.len(), x.nrows());
        }

        #[test]
        fn prop_bagging_oob_score_bounds(
            n_estimators in 5usize..15,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
                [5.0, 6.0], [6.0, 7.0], [7.0, 8.0], [8.0, 9.0],
                [9.0, 10.0], [10.0, 11.0], [11.0, 12.0], [12.0, 13.0],
            ];
            let y = array![0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 1, 2];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .oob_score(true)
                .bootstrap(true)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            if let Some(oob_score) = classifier.oob_score() {
                prop_assert!(oob_score >= 0.0 && oob_score <= 1.0);
            }
        }

        #[test]
        fn prop_bagging_confidence_intervals_bounds(
            n_estimators in 3usize..8,
            confidence_level in 0.7..0.99,
        ) {
            let x = array![
                [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
            ];
            let y = array![0, 0, 1, 1];

            let classifier = BaggingClassifier::new()
                .n_estimators(n_estimators)
                .confidence_level(confidence_level)
                .random_state(42)
                .fit(&x, &y)
                .unwrap();

            let (predictions, confidence_intervals) =
                classifier.predict_with_confidence(&x).unwrap();

            for i in 0..predictions.len() {
                let lower = confidence_intervals[[i, 0]];
                let upper = confidence_intervals[[i, 1]];

                prop_assert!(lower <= upper);
                prop_assert!(lower.is_finite() && upper.is_finite());
            }
        }
    }
}