use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::SliceRandomExt;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict, Trained, Untrained},
    types::Float,
};
use std::collections::HashMap;
use std::marker::PhantomData;

use smartcore::ensemble::random_forest_classifier::{
    RandomForestClassifier as SmartCoreClassifier, RandomForestClassifierParameters,
};
use smartcore::ensemble::random_forest_regressor::{
    RandomForestRegressor as SmartCoreRegressor, RandomForestRegressorParameters,
};
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::SplitCriterion as ClassifierCriterion;

use crate::{ndarray_to_dense_matrix, MaxFeatures, SplitCriterion};

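/// Strategy for weighting classes on imbalanced data; consumed by
/// `calculate_class_weights`.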
#[derive(Debug, Clone)]
pub enum ClassWeight {
    /// Every class receives a weight of 1.0.
    None,
    /// Weights inversely proportional to class frequency:
    /// `n_samples / (n_classes * count(class))`.
    Balanced,
    /// Explicit per-class weights keyed by class label.
    Custom(HashMap<i32, f64>),
}

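/// Strategy for drawing per-tree training samples; consumed by
/// `balanced_bootstrap_sample`.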
#[derive(Debug, Clone, Copy)]
pub enum SamplingStrategy {
    /// Standard bootstrap: sample uniformly with replacement.
    Bootstrap,
    /// Draw (roughly) the same number of samples from each class.
    BalancedBootstrap,
    /// Preserve the original class proportions in each sample.
    Stratified,
    /// Oversample minority classes toward the majority class size.
    SMOTEBootstrap,
}

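/// Configuration shared by `RandomForestClassifier` and
/// `RandomForestRegressor`. Defaults follow scikit-learn's
/// `RandomForestClassifier`: 100 trees, Gini criterion, sqrt(n_features)
/// features per split, and bootstrap sampling enabled.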
#[derive(Debug, Clone)]
pub struct RandomForestConfig {
    pub n_estimators: usize,
    pub criterion: SplitCriterion,
    pub max_depth: Option<usize>,
    pub min_samples_split: usize,
    pub min_samples_leaf: usize,
    pub max_features: MaxFeatures,
    pub bootstrap: bool,
    pub oob_score: bool,
    pub random_state: Option<u64>,
    pub n_jobs: Option<i32>,
    pub min_weight_fraction_leaf: f64,
    pub max_leaf_nodes: Option<usize>,
    pub min_impurity_decrease: f64,
    pub warm_start: bool,
    pub class_weight: ClassWeight,
    pub sampling_strategy: SamplingStrategy,
}

impl Default for RandomForestConfig {
    fn default() -> Self {
        Self {
            n_estimators: 100,
            criterion: SplitCriterion::Gini,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            bootstrap: true,
            oob_score: false,
            random_state: None,
            n_jobs: None,
            min_weight_fraction_leaf: 0.0,
            max_leaf_nodes: None,
            min_impurity_decrease: 0.0,
            warm_start: false,
            class_weight: ClassWeight::None,
            sampling_strategy: SamplingStrategy::Bootstrap,
        }
    }
}

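/// Random forest classifier with a typestate (`Untrained`/`Trained`) builder
/// API, backed by SmartCore's `RandomForestClassifier`.
///
/// A minimal usage sketch (marked `ignore` since it assumes this crate's
/// traits and re-exports are in scope):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[0.0, 0.0], [1.0, 1.0], [3.0, 3.0], [4.0, 4.0]];
/// let y = array![0, 0, 1, 1];
/// let model = RandomForestClassifier::new()
///     .n_estimators(10)
///     .max_depth(3)
///     .fit(&x, &y)?;
/// let predictions = model.predict(&x)?;
/// ```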
pub struct RandomForestClassifier<State = Untrained> {
    config: RandomForestConfig,
    state: PhantomData<State>,
    model_: Option<SmartCoreClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>>,
    classes_: Option<Array1<i32>>,
    n_classes_: Option<usize>,
    n_features_: Option<usize>,
    #[allow(dead_code)]
    n_outputs_: Option<usize>,
    oob_score_: Option<f64>,
    oob_decision_function_: Option<Array2<f64>>,
    proximity_matrix_: Option<Array2<f64>>,
}

impl RandomForestClassifier<Untrained> {
    pub fn new() -> Self {
        Self {
            config: RandomForestConfig::default(),
            state: PhantomData,
            model_: None,
            classes_: None,
            n_classes_: None,
            n_features_: None,
            n_outputs_: None,
            oob_score_: None,
            oob_decision_function_: None,
            proximity_matrix_: None,
        }
    }

    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
        self.config.criterion = criterion;
        self
    }

    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.config.max_depth = Some(max_depth);
        self
    }

    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.config.min_samples_split = min_samples_split;
        self
    }

    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.config.min_samples_leaf = min_samples_leaf;
        self
    }

    pub fn max_features(mut self, max_features: MaxFeatures) -> Self {
        self.config.max_features = max_features;
        self
    }

    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.config.bootstrap = bootstrap;
        self
    }

    pub fn oob_score(mut self, oob_score: bool) -> Self {
        self.config.oob_score = oob_score;
        self
    }

    pub fn class_weight(mut self, class_weight: ClassWeight) -> Self {
        self.config.class_weight = class_weight;
        self
    }

    pub fn sampling_strategy(mut self, sampling_strategy: SamplingStrategy) -> Self {
        self.config.sampling_strategy = sampling_strategy;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
        self.config.n_jobs = Some(n_jobs);
        self
    }

    pub fn min_impurity_decrease(mut self, min_impurity_decrease: f64) -> Self {
        self.config.min_impurity_decrease = min_impurity_decrease;
        self
    }

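    /// Approximates the out-of-bag (OOB) score with a cross-validation-style
    /// scheme: small auxiliary forests are trained on held-in folds and
    /// scored on the held-out fold, so each sample gets a prediction from a
    /// model that never saw it. This approximates, but does not reproduce,
    /// the per-tree bootstrap-mask OOB estimate of classic random forests.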
    fn compute_oob_score(
        model: &SmartCoreClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
        x: &Array2<Float>,
        y: &Array1<i32>,
        classes: &[i32],
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = x.nrows();
        let n_classes = classes.len();

        let n_folds = 5.min(n_samples / 10);
        if n_folds < 2 {
            log::warn!("Dataset too small for proper OOB estimation, using simple validation");
            return Self::compute_simple_validation_score(model, x, y, classes);
        }

        let fold_size = n_samples / n_folds;
        let mut oob_predictions = vec![-1; n_samples];
        let mut oob_decision_matrix = Array2::zeros((n_samples, n_classes));
        let mut oob_counts = vec![0; n_samples];

        for fold in 0..n_folds {
            let start_idx = fold * fold_size;
            let end_idx = if fold == n_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            let mut train_indices = Vec::new();
            let mut oob_indices = Vec::new();

            for i in 0..n_samples {
                if i >= start_idx && i < end_idx {
                    oob_indices.push(i);
                } else {
                    train_indices.push(i);
                }
            }

            if train_indices.is_empty() || oob_indices.is_empty() {
                continue;
            }

            let train_x = {
                let mut data = Array2::zeros((train_indices.len(), x.ncols()));
                for (new_idx, &orig_idx) in train_indices.iter().enumerate() {
                    data.row_mut(new_idx).assign(&x.row(orig_idx));
                }
                data
            };
            let train_y = Array1::from_vec(train_indices.iter().map(|&i| y[i]).collect());

            let train_x_matrix = crate::ndarray_to_dense_matrix(&train_x);
            let train_y_vec = train_y.to_vec();

            // A deliberately small ensemble keeps the per-fold refits cheap.
            let small_ensemble_params =
                smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters::default()
                    .with_n_trees(3)
                    .with_max_depth(5);

            if let Ok(fold_model) =
                SmartCoreClassifier::fit(&train_x_matrix, &train_y_vec, small_ensemble_params)
            {
                let oob_x = {
                    let mut data = Array2::zeros((oob_indices.len(), x.ncols()));
                    for (new_idx, &orig_idx) in oob_indices.iter().enumerate() {
                        data.row_mut(new_idx).assign(&x.row(orig_idx));
                    }
                    data
                };
                let oob_x_matrix = crate::ndarray_to_dense_matrix(&oob_x);

                if let Ok(fold_predictions) = fold_model.predict(&oob_x_matrix) {
                    for (local_idx, &orig_idx) in oob_indices.iter().enumerate() {
                        let pred = fold_predictions[local_idx];
                        oob_predictions[orig_idx] = pred;
                        oob_counts[orig_idx] += 1;

                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            oob_decision_matrix[[orig_idx, class_idx]] += 1.0;
                        }
                    }
                }
            }
        }

        let mut correct_oob = 0;
        let mut total_oob = 0;

        for i in 0..n_samples {
            if oob_counts[i] > 0 {
                // Normalize vote counts into per-class frequencies.
                let count = oob_counts[i] as f64;
                for j in 0..n_classes {
                    oob_decision_matrix[[i, j]] /= count;
                }

                if oob_predictions[i] == y[i] {
                    correct_oob += 1;
                }
                total_oob += 1;
            }
        }

        let oob_accuracy = if total_oob > 0 {
            correct_oob as f64 / total_oob as f64
        } else {
            log::warn!("No OOB samples available, falling back to main model");
            return Self::compute_simple_validation_score(model, x, y, classes);
        };

        Ok((oob_accuracy, oob_decision_matrix))
    }

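    /// Fallback scoring for datasets too small for the fold-based OOB
    /// approximation: evaluates the fitted model on the training data and
    /// returns a one-hot decision function. Being a resubstitution estimate,
    /// it is optimistically biased.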
    fn compute_simple_validation_score(
        model: &SmartCoreClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
        x: &Array2<Float>,
        y: &Array1<i32>,
        classes: &[i32],
    ) -> Result<(f64, Array2<f64>)> {
        let n_samples = x.nrows();
        let n_classes = classes.len();

        let x_matrix = crate::ndarray_to_dense_matrix(x);
        let predictions = model.predict(&x_matrix).map_err(|e| {
            SklearsError::PredictError(format!("Validation prediction failed: {e:?}"))
        })?;

        let mut correct = 0;
        for (i, &pred) in predictions.iter().enumerate() {
            if pred == y[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / n_samples as f64;

        let mut decision_function = Array2::zeros((n_samples, n_classes));
        for (i, &pred) in predictions.iter().enumerate() {
            if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                decision_function[[i, class_idx]] = 1.0;
            }
        }

        Ok((accuracy, decision_function))
    }
}

impl RandomForestClassifier<Trained> {
    pub fn classes(&self) -> &Array1<i32> {
        self.classes_.as_ref().expect("Model should be fitted")
    }

    pub fn n_classes(&self) -> usize {
        self.n_classes_.expect("Model should be fitted")
    }

    pub fn n_features(&self) -> usize {
        self.n_features_.expect("Model should be fitted")
    }

    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score_
    }

    pub fn oob_decision_function(&self) -> Option<&Array2<f64>> {
        self.oob_decision_function_.as_ref()
    }

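    /// Computes a proximity matrix over the rows of `x`. Classic Breiman
    /// proximities count how often two samples share a leaf across trees;
    /// SmartCore does not expose leaf assignments, so this uses a coarse
    /// proxy: 0.8 when two samples receive the same predicted class, 0.2
    /// otherwise, and 1.0 on the diagonal.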
    pub fn compute_proximity_matrix(&self, x: &Array2<Float>) -> Result<Array2<f64>> {
        let n_samples = x.nrows();
        let mut proximity_matrix = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in i..n_samples {
                let sample_i = x.row(i);
                let sample_j = x.row(j);

                let sample_i_owned = sample_i
                    .to_owned()
                    .insert_axis(scirs2_core::ndarray::Axis(0));
                let sample_j_owned = sample_j
                    .to_owned()
                    .insert_axis(scirs2_core::ndarray::Axis(0));
                let pred_i = self.predict(&sample_i_owned)?;
                let pred_j = self.predict(&sample_j_owned)?;

                let proximity = if i == j {
                    1.0
                } else if pred_i[0] == pred_j[0] {
                    0.8
                } else {
                    0.2
                };

                // The matrix is symmetric by construction.
                proximity_matrix[(i, j)] = proximity;
                proximity_matrix[(j, i)] = proximity;
            }
        }

        Ok(proximity_matrix)
    }

    pub fn proximity_matrix(&self) -> Option<&Array2<f64>> {
        self.proximity_matrix_.as_ref()
    }

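    /// Predicts class labels in parallel: `x` is split into row chunks, each
    /// chunk is predicted on a thread pool, and the results are concatenated
    /// in order. The output matches `predict`.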
    pub fn predict_parallel(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
        use crate::parallel::{ParallelTreeExt, ParallelUtils};

        let model = self.model_.as_ref().expect("Model should be fitted");

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let n_threads = ParallelUtils::optimal_n_threads(self.config.n_jobs);

        let result = ParallelUtils::with_thread_pool(n_threads, || {
            // Ceiling division so every row lands in some chunk.
            let chunk_size = (x.nrows() + n_threads - 1) / n_threads;
            let chunks: Vec<_> = x
                .axis_chunks_iter(scirs2_core::ndarray::Axis(0), chunk_size)
                .collect();

            let chunk_results: Vec<Result<Array1<i32>>> = chunks
                .into_iter()
                .enumerate()
                .maybe_parallel_process(|(_, chunk)| {
                    let chunk_matrix = crate::ndarray_to_dense_matrix(&chunk.to_owned());
                    model
                        .predict(&chunk_matrix)
                        .map(Array1::from_vec)
                        .map_err(|e| {
                            SklearsError::PredictError(format!("Parallel prediction failed: {e:?}"))
                        })
                });

            let mut total_predictions = Vec::new();
            for chunk_result in chunk_results {
                match chunk_result {
                    Ok(predictions) => total_predictions.extend(predictions.to_vec()),
                    Err(e) => return Err(e),
                }
            }

            Ok(Array1::from_vec(total_predictions))
        });

        result
    }

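    /// Estimates class probabilities by averaging one-hot predictions over
    /// several passes on minutely perturbed copies of `x` (deterministic
    /// noise on the order of 1e-6). This is a heuristic substitute for true
    /// per-tree vote counting, which SmartCore does not expose.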
    pub fn predict_proba_parallel(&self, x: &Array2<Float>) -> Result<Array2<f64>> {
        use crate::parallel::{ParallelTreeExt, ParallelUtils};

        let model = self.model_.as_ref().expect("Model should be fitted");

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let n_samples = x.nrows();
        let n_classes = self.n_classes();
        let n_threads = ParallelUtils::optimal_n_threads(self.config.n_jobs);

        ParallelUtils::with_thread_pool(n_threads, || {
            let n_iterations = 10;
            let matrix_results: Vec<Result<Array2<f64>>> = (0..n_iterations)
                .maybe_parallel_process(|iteration| {
                    let mut x_perturbed = x.clone();

                    // Deterministic, tiny perturbation so runs are reproducible.
                    let noise_scale = 1e-6;
                    for i in 0..n_samples {
                        for j in 0..x.ncols() {
                            let noise = ((iteration * i + j) as f64 * 0.123) % 1.0 - 0.5;
                            x_perturbed[[i, j]] += noise * noise_scale;
                        }
                    }

                    let x_matrix = crate::ndarray_to_dense_matrix(&x_perturbed);
                    let predictions = model.predict(&x_matrix).map_err(|e| {
                        SklearsError::PredictError(format!(
                            "Parallel probability prediction failed: {e:?}"
                        ))
                    })?;

                    // One-hot encode this iteration's predictions.
                    let mut prob_matrix = Array2::zeros((n_samples, n_classes));
                    for (sample_idx, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = self.classes().iter().position(|&c| c == pred) {
                            prob_matrix[[sample_idx, class_idx]] = 1.0;
                        }
                    }

                    Ok(prob_matrix)
                });

            let mut probability_matrices = Vec::new();
            for matrix_result in matrix_results {
                match matrix_result {
                    Ok(matrix) => probability_matrices.push(matrix),
                    Err(e) => return Err(e),
                }
            }

            ParallelUtils::parallel_predict_proba_aggregate(probability_matrices)
        })
    }

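    /// Placeholder feature importances: SmartCore does not expose
    /// impurity-based importances, so every feature receives `1 / n_features`.
    /// Prefer `permutation_feature_importance` for a data-driven estimate.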
    pub fn feature_importances(&self) -> Result<Array1<f64>> {
        if let Some(ref _model) = self.model_ {
            let n_features = self.n_features_.unwrap_or(0);

            let mut importances = Array1::zeros(n_features);
            let uniform_importance = 1.0 / n_features as f64;

            for i in 0..n_features {
                importances[i] = uniform_importance;
            }

            Ok(importances)
        } else {
            Err(SklearsError::NotFitted {
                operation: "feature_importances".to_string(),
            })
        }
    }

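    /// Permutation feature importance for classification: permute one
    /// feature's column, re-predict, and record the accuracy drop relative to
    /// the unpermuted baseline; scores are averaged over `n_repeats`, clamped
    /// at zero, and normalized to sum to 1. The permutation is a fixed,
    /// deterministic swap pattern rather than a random shuffle, so results
    /// are reproducible across calls.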
    pub fn permutation_feature_importance(
        &self,
        x: &Array2<Float>,
        y: &Array1<i32>,
        n_repeats: usize,
    ) -> Result<Array1<f64>> {
        if self.model_.is_none() {
            return Err(SklearsError::NotFitted {
                operation: "permutation_feature_importance".to_string(),
            });
        }

        let n_features = x.ncols();
        let n_samples = x.nrows();

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: "X.shape[0] == y.shape[0]".to_string(),
                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
            });
        }

        let baseline_predictions = self.predict(x)?;
        let baseline_accuracy = baseline_predictions
            .iter()
            .zip(y.iter())
            .filter(|(&pred, &actual)| pred == actual)
            .count() as f64
            / n_samples as f64;

        let mut importances = Array1::zeros(n_features);

        for feature_idx in 0..n_features {
            let mut importance_scores = Vec::new();

            for _ in 0..n_repeats {
                let mut x_permuted = x.clone();

                let mut feature_values: Vec<f64> = x.column(feature_idx).to_vec();

                // Deterministic pseudo-shuffle: swap each position with a fixed partner.
                for i in 0..feature_values.len() {
                    let j = (i * 17 + 42) % feature_values.len();
                    feature_values.swap(i, j);
                }

                for (row_idx, &permuted_value) in feature_values.iter().enumerate() {
                    x_permuted[[row_idx, feature_idx]] = permuted_value;
                }

                if let Ok(permuted_predictions) = self.predict(&x_permuted) {
                    let permuted_accuracy = permuted_predictions
                        .iter()
                        .zip(y.iter())
                        .filter(|(&pred, &actual)| pred == actual)
                        .count() as f64
                        / n_samples as f64;

                    // Importance is the accuracy lost when the feature is scrambled.
                    let importance = baseline_accuracy - permuted_accuracy;
                    importance_scores.push(importance);
                }
            }

            if !importance_scores.is_empty() {
                importances[feature_idx] =
                    importance_scores.iter().sum::<f64>() / importance_scores.len() as f64;
            }
        }

        // Clamp negative scores and normalize so the importances sum to 1.
        for importance in importances.iter_mut() {
            if *importance < 0.0 {
                *importance = 0.0;
            }
        }

        let sum = importances.sum();
        if sum > 0.0 {
            importances /= sum;
        }

        Ok(importances)
    }

    pub fn predict_proba(&self, _x: &Array2<Float>) -> Result<Array2<f64>> {
        Err(SklearsError::NotImplemented(
            "predict_proba not available in SmartCore RandomForestClassifier".to_string(),
        ))
    }
}

impl Default for RandomForestClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RandomForestClassifier<Untrained> {
    type Config = RandomForestConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<i32>> for RandomForestClassifier<Untrained> {
    type Fitted = RandomForestClassifier<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<i32>) -> Result<Self::Fitted> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: "X.shape[0] == y.shape[0]".to_string(),
                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
            });
        }

        let x_matrix = ndarray_to_dense_matrix(x);
        let y_vec = y.to_vec();

        // Resolved here for completeness, but not currently forwarded to SmartCore.
        let _max_features = match self.config.max_features {
            MaxFeatures::All => n_features,
            MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
            MaxFeatures::Log2 => (n_features as f64).log2().ceil() as usize,
            MaxFeatures::Number(n) => n.min(n_features),
            MaxFeatures::Fraction(f) => ((n_features as f64 * f).ceil() as usize).min(n_features),
        };

        let criterion = match self.config.criterion {
            SplitCriterion::Gini => ClassifierCriterion::Gini,
            SplitCriterion::Entropy => ClassifierCriterion::Entropy,
            _ => {
                return Err(SklearsError::InvalidParameter {
                    name: "criterion".to_string(),
                    reason: "MSE and MAE are only valid for regression".to_string(),
                })
            }
        };

        let mut parameters = RandomForestClassifierParameters::default()
            .with_n_trees(self.config.n_estimators as u16)
            .with_min_samples_split(self.config.min_samples_split)
            .with_min_samples_leaf(self.config.min_samples_leaf)
            .with_criterion(criterion);

        if let Some(max_depth) = self.config.max_depth {
            parameters = parameters.with_max_depth(max_depth as u16);
        }

        let model = SmartCoreClassifier::fit(&x_matrix, &y_vec, parameters)
            .map_err(|e| SklearsError::FitError(format!("Random forest fit failed: {e:?}")))?;

        // Sorted, de-duplicated class labels.
        let mut classes: Vec<i32> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let classes_array = Array1::from_vec(classes.clone());
        let n_classes = classes.len();

        let (oob_score, oob_decision_function) = if self.config.oob_score && self.config.bootstrap {
            let (score, decisions) = Self::compute_oob_score(&model, x, y, &classes)?;
            (Some(score), Some(decisions))
        } else {
            (None, None)
        };

        Ok(RandomForestClassifier {
            config: self.config,
            state: PhantomData,
            model_: Some(model),
            classes_: Some(classes_array),
            n_classes_: Some(n_classes),
            n_features_: Some(n_features),
            n_outputs_: Some(1),
            oob_score_: oob_score,
            oob_decision_function_: oob_decision_function,
            proximity_matrix_: None,
        })
    }
}

impl Predict<Array2<Float>, Array1<i32>> for RandomForestClassifier<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
        let model = self.model_.as_ref().expect("Model should be fitted");

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let x_matrix = ndarray_to_dense_matrix(x);
        let predictions = model
            .predict(&x_matrix)
            .map_err(|e| SklearsError::PredictError(format!("Prediction failed: {e:?}")))?;

        Ok(Array1::from_vec(predictions))
    }
}

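/// Random forest regressor with the same typestate (`Untrained`/`Trained`)
/// builder API as `RandomForestClassifier`, backed by SmartCore's
/// `RandomForestRegressor`. OOB scoring is not yet implemented here, so
/// `oob_score()` currently always returns `None`.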
pub struct RandomForestRegressor<State = Untrained> {
    config: RandomForestConfig,
    state: PhantomData<State>,
    model_: Option<SmartCoreRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>>,
    n_features_: Option<usize>,
    #[allow(dead_code)]
    n_outputs_: Option<usize>,
    oob_score_: Option<f64>,
    proximity_matrix_: Option<Array2<f64>>,
}

impl RandomForestRegressor<Untrained> {
    pub fn new() -> Self {
        Self {
            config: RandomForestConfig::default(),
            state: PhantomData,
            model_: None,
            n_features_: None,
            n_outputs_: None,
            oob_score_: None,
            proximity_matrix_: None,
        }
    }

    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
        self.config.criterion = criterion;
        self
    }

    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.config.max_depth = Some(max_depth);
        self
    }

    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.config.min_samples_split = min_samples_split;
        self
    }

    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.config.min_samples_leaf = min_samples_leaf;
        self
    }

    pub fn max_features(mut self, max_features: MaxFeatures) -> Self {
        self.config.max_features = max_features;
        self
    }

    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.config.bootstrap = bootstrap;
        self
    }

    pub fn oob_score(mut self, oob_score: bool) -> Self {
        self.config.oob_score = oob_score;
        self
    }

    pub fn class_weight(mut self, class_weight: ClassWeight) -> Self {
        self.config.class_weight = class_weight;
        self
    }

    pub fn sampling_strategy(mut self, sampling_strategy: SamplingStrategy) -> Self {
        self.config.sampling_strategy = sampling_strategy;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
        self.config.n_jobs = Some(n_jobs);
        self
    }
}

impl RandomForestRegressor<Trained> {
    pub fn n_features(&self) -> usize {
        self.n_features_.expect("Model should be fitted")
    }

    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score_
    }

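    /// Computes a proximity matrix for regression samples. As with the
    /// classifier, per-tree leaf co-occurrence is unavailable from SmartCore,
    /// so similarity is banded by the absolute difference between the two
    /// samples' predictions (0.9, 0.7, 0.4, or 0.1), with 1.0 on the diagonal.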
    pub fn compute_proximity_matrix(&self, x: &Array2<Float>) -> Result<Array2<f64>> {
        let n_samples = x.nrows();
        let mut proximity_matrix = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in i..n_samples {
                let sample_i = x.row(i);
                let sample_j = x.row(j);

                let sample_i_owned = sample_i
                    .to_owned()
                    .insert_axis(scirs2_core::ndarray::Axis(0));
                let sample_j_owned = sample_j
                    .to_owned()
                    .insert_axis(scirs2_core::ndarray::Axis(0));
                let pred_i = self.predict(&sample_i_owned)?;
                let pred_j = self.predict(&sample_j_owned)?;

                // Band the similarity by how far apart the two predictions are.
                let proximity = if i == j {
                    1.0
                } else {
                    let diff = (pred_i[0] - pred_j[0]).abs();
                    if diff < 0.1 {
                        0.9
                    } else if diff < 1.0 {
                        0.7
                    } else if diff < 5.0 {
                        0.4
                    } else {
                        0.1
                    }
                };

                proximity_matrix[(i, j)] = proximity;
                proximity_matrix[(j, i)] = proximity;
            }
        }

        Ok(proximity_matrix)
    }

    pub fn proximity_matrix(&self) -> Option<&Array2<f64>> {
        self.proximity_matrix_.as_ref()
    }

    pub fn predict_parallel(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        use crate::parallel::{ParallelTreeExt, ParallelUtils};

        let model = self.model_.as_ref().expect("Model should be fitted");

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let n_threads = ParallelUtils::optimal_n_threads(self.config.n_jobs);

        let result = ParallelUtils::with_thread_pool(n_threads, || {
            // Ceiling division so every row lands in some chunk.
            let chunk_size = (x.nrows() + n_threads - 1) / n_threads;
            let chunks: Vec<_> = x
                .axis_chunks_iter(scirs2_core::ndarray::Axis(0), chunk_size)
                .collect();

            let chunk_results: Vec<Result<Array1<Float>>> = chunks
                .into_iter()
                .enumerate()
                .maybe_parallel_process(|(_, chunk)| {
                    let chunk_matrix = crate::ndarray_to_dense_matrix(&chunk.to_owned());
                    model
                        .predict(&chunk_matrix)
                        .map(Array1::from_vec)
                        .map_err(|e| {
                            SklearsError::PredictError(format!("Parallel prediction failed: {e:?}"))
                        })
                });

            let mut total_predictions = Vec::new();
            for chunk_result in chunk_results {
                match chunk_result {
                    Ok(predictions) => total_predictions.extend(predictions.to_vec()),
                    Err(e) => return Err(e),
                }
            }

            Ok(Array1::from_vec(total_predictions))
        });

        result
    }

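    /// Placeholder feature importances: SmartCore does not expose
    /// impurity-based importances, so every feature receives `1 / n_features`.
    /// Prefer `permutation_feature_importance` for a data-driven estimate.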
    pub fn feature_importances(&self) -> Result<Array1<f64>> {
        if let Some(ref _model) = self.model_ {
            let n_features = self.n_features_.unwrap_or(0);

            let mut importances = Array1::zeros(n_features);
            let uniform_importance = 1.0 / n_features as f64;

            for i in 0..n_features {
                importances[i] = uniform_importance;
            }

            Ok(importances)
        } else {
            Err(SklearsError::NotFitted {
                operation: "feature_importances".to_string(),
            })
        }
    }

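    /// Permutation importance for regression: permute one feature's column,
    /// re-predict, and record the MSE increase over the unpermuted baseline;
    /// scores are averaged over `n_repeats`, clamped at zero, and normalized
    /// to sum to 1. The permutation uses a fixed deterministic swap pattern,
    /// so results are reproducible across calls.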
    pub fn permutation_feature_importance(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        n_repeats: usize,
    ) -> Result<Array1<f64>> {
        if self.model_.is_none() {
            return Err(SklearsError::NotFitted {
                operation: "permutation_feature_importance".to_string(),
            });
        }

        let n_features = x.ncols();
        let n_samples = x.nrows();

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: "X.shape[0] == y.shape[0]".to_string(),
                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
            });
        }

        let baseline_predictions = self.predict(x)?;
        let baseline_mse = baseline_predictions
            .iter()
            .zip(y.iter())
            .map(|(&pred, &actual)| (pred - actual).powi(2))
            .sum::<f64>()
            / n_samples as f64;

        let mut importances = Array1::zeros(n_features);

        for feature_idx in 0..n_features {
            let mut importance_scores = Vec::new();

            for _ in 0..n_repeats {
                let mut x_permuted = x.clone();

                let mut feature_values: Vec<f64> = x.column(feature_idx).to_vec();

                // Deterministic pseudo-shuffle: swap each position with a fixed partner.
                for i in 0..feature_values.len() {
                    let j = (i * 17 + 42) % feature_values.len();
                    feature_values.swap(i, j);
                }

                for (row_idx, &permuted_value) in feature_values.iter().enumerate() {
                    x_permuted[[row_idx, feature_idx]] = permuted_value;
                }

                if let Ok(permuted_predictions) = self.predict(&x_permuted) {
                    let permuted_mse = permuted_predictions
                        .iter()
                        .zip(y.iter())
                        .map(|(&pred, &actual)| (pred - actual).powi(2))
                        .sum::<f64>()
                        / n_samples as f64;

                    // Importance is the MSE increase when the feature is scrambled.
                    let importance = permuted_mse - baseline_mse;
                    importance_scores.push(importance);
                }
            }

            if !importance_scores.is_empty() {
                importances[feature_idx] =
                    importance_scores.iter().sum::<f64>() / importance_scores.len() as f64;
            }
        }

        // Clamp negative scores and normalize so the importances sum to 1.
        for importance in importances.iter_mut() {
            if *importance < 0.0 {
                *importance = 0.0;
            }
        }

        let sum = importances.sum();
        if sum > 0.0 {
            importances /= sum;
        }

        Ok(importances)
    }
}

impl Default for RandomForestRegressor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RandomForestRegressor<Untrained> {
    type Config = RandomForestConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<Float>> for RandomForestRegressor<Untrained> {
    type Fitted = RandomForestRegressor<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: "X.shape[0] == y.shape[0]".to_string(),
                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
            });
        }

        let x_matrix = ndarray_to_dense_matrix(x);
        let y_vec = y.to_vec();

        // Resolved here for completeness, but not currently forwarded to SmartCore.
        let _max_features = match self.config.max_features {
            MaxFeatures::All => n_features,
            MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
            MaxFeatures::Log2 => (n_features as f64).log2().ceil() as usize,
            MaxFeatures::Number(n) => n.min(n_features),
            MaxFeatures::Fraction(f) => ((n_features as f64 * f).ceil() as usize).min(n_features),
        };

        match self.config.criterion {
            SplitCriterion::MSE | SplitCriterion::MAE => {}
            _ => {
                return Err(SklearsError::InvalidParameter {
                    name: "criterion".to_string(),
                    reason: "Gini and Entropy are only valid for classification".to_string(),
                })
            }
        };

        let mut parameters = RandomForestRegressorParameters::default()
            .with_n_trees(self.config.n_estimators)
            .with_min_samples_split(self.config.min_samples_split)
            .with_min_samples_leaf(self.config.min_samples_leaf);

        if let Some(max_depth) = self.config.max_depth {
            parameters = parameters.with_max_depth(max_depth as u16);
        }

        let model = SmartCoreRegressor::fit(&x_matrix, &y_vec, parameters)
            .map_err(|e| SklearsError::FitError(format!("Random forest fit failed: {e:?}")))?;

        Ok(RandomForestRegressor {
            config: self.config,
            state: PhantomData,
            model_: Some(model),
            n_features_: Some(n_features),
            n_outputs_: Some(1),
            // OOB scoring is not implemented for the regressor yet.
            oob_score_: None,
            proximity_matrix_: None,
        })
    }
}

impl Predict<Array2<Float>, Array1<Float>> for RandomForestRegressor<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let model = self.model_.as_ref().expect("Model should be fitted");

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let x_matrix = ndarray_to_dense_matrix(x);
        let predictions = model
            .predict(&x_matrix)
            .map_err(|e| SklearsError::PredictError(format!("Prediction failed: {e:?}")))?;

        Ok(Array1::from_vec(predictions))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_random_forest_classifier() {
        let x = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
        ];
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(10)
            .max_depth(3)
            .criterion(SplitCriterion::Gini)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(model.n_features(), 2);
        assert_eq!(model.n_classes(), 2);

        let predictions = model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_random_forest_regressor() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0, 25.0];

        let model = RandomForestRegressor::new()
            .n_estimators(20)
            .max_depth(5)
            .criterion(SplitCriterion::MSE)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(model.n_features(), 1);

        let predictions = model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 6);

        // Interpolating between training points should stay in a sane range.
        let test_x = array![[2.5]];
        let test_pred = model.predict(&test_x).unwrap();
        assert_eq!(test_pred.len(), 1);
        assert!(test_pred[0] > 3.0 && test_pred[0] < 10.0);
    }

    #[test]
    fn test_random_forest_classifier_feature_importances() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(5)
            .fit(&x, &y)
            .unwrap();

        let importances = model.feature_importances().unwrap();

        assert_eq!(importances.len(), 3);

        let sum: f64 = importances.sum();
        assert!((sum - 1.0).abs() < f64::EPSILON);

        // The placeholder implementation returns uniform importances.
        let expected = 1.0 / 3.0;
        for &importance in importances.iter() {
            assert!((importance - expected).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_random_forest_regressor_feature_importances() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![10.0, 20.0, 30.0, 40.0];

        let model = RandomForestRegressor::new()
            .n_estimators(3)
            .criterion(SplitCriterion::MSE)
            .fit(&x, &y)
            .unwrap();

        let importances = model.feature_importances().unwrap();

        assert_eq!(importances.len(), 2);

        let sum: f64 = importances.sum();
        assert!((sum - 1.0).abs() < f64::EPSILON);

        let expected = 1.0 / 2.0;
        for &importance in importances.iter() {
            assert!((importance - expected).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_feature_importances_not_fitted() {
        // feature_importances only exists on the Trained typestate, so an
        // unfitted model cannot even call it; here we just sanity-check the
        // untrained builder's defaults.
        let model = RandomForestClassifier::new();
        assert_eq!(model.config.n_estimators, 100);
    }

    #[test]
    fn test_random_forest_regressor_proximity_matrix() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let model = RandomForestRegressor::new()
            .n_estimators(5)
            .max_depth(3)
            .criterion(SplitCriterion::MSE)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        assert!(model.proximity_matrix().is_none());

        let proximity = model.compute_proximity_matrix(&x).unwrap();

        assert_eq!(proximity.shape(), &[4, 4]);

        // Diagonal entries are exactly 1.
        for i in 0..4 {
            assert!((proximity[(i, i)] - 1.0).abs() < f64::EPSILON);
        }

        // The matrix is symmetric.
        for i in 0..4 {
            for j in 0..4 {
                assert!((proximity[(i, j)] - proximity[(j, i)]).abs() < f64::EPSILON);
            }
        }

        // All proximities lie in [0, 1].
        for i in 0..4 {
            for j in 0..4 {
                assert!(proximity[(i, j)] >= 0.0 && proximity[(i, j)] <= 1.0);
            }
        }
    }

    #[test]
    fn test_random_forest_classifier_parallel_predict() {
        let x = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
        ];
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(10)
            .max_depth(3)
            .criterion(SplitCriterion::Gini)
            .random_state(42)
            .n_jobs(2)
            .fit(&x, &y)
            .unwrap();

        let parallel_predictions = model.predict_parallel(&x).unwrap();
        let serial_predictions = model.predict(&x).unwrap();

        assert_eq!(parallel_predictions.len(), serial_predictions.len());
        assert_eq!(parallel_predictions.len(), 6);

        // Parallel prediction must agree with the serial path.
        for (parallel, serial) in parallel_predictions.iter().zip(serial_predictions.iter()) {
            assert_eq!(parallel, serial);
        }
    }

    #[test]
    fn test_random_forest_classifier_parallel_predict_proba() {
        let x = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
        ];
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(10)
            .max_depth(3)
            .criterion(SplitCriterion::Gini)
            .random_state(42)
            .n_jobs(2)
            .fit(&x, &y)
            .unwrap();

        let probabilities = model.predict_proba_parallel(&x).unwrap();

        assert_eq!(probabilities.shape(), &[6, 2]);

        // Each row is a probability distribution over the classes.
        for i in 0..6 {
            let row_sum: f64 = probabilities.row(i).sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-10,
                "Row {}: sum = {}",
                i,
                row_sum
            );
        }

        for prob in probabilities.iter() {
            assert!(
                *prob >= 0.0 && *prob <= 1.0,
                "Invalid probability: {}",
                prob
            );
        }
    }

    #[test]
    fn test_random_forest_regressor_parallel_predict() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0, 25.0];

        let model = RandomForestRegressor::new()
            .n_estimators(20)
            .max_depth(5)
            .criterion(SplitCriterion::MSE)
            .random_state(42)
            .n_jobs(2)
            .fit(&x, &y)
            .unwrap();

        let parallel_predictions = model.predict_parallel(&x).unwrap();
        let serial_predictions = model.predict(&x).unwrap();

        assert_eq!(parallel_predictions.len(), serial_predictions.len());
        assert_eq!(parallel_predictions.len(), 6);

        for (parallel, serial) in parallel_predictions.iter().zip(serial_predictions.iter()) {
            assert_eq!(parallel, serial);
        }

        let test_x = array![[2.5]];
        let test_parallel_pred = model.predict_parallel(&test_x).unwrap();
        let test_serial_pred = model.predict(&test_x).unwrap();

        assert_eq!(test_parallel_pred.len(), 1);
        assert_eq!(test_serial_pred.len(), 1);
        assert_eq!(test_parallel_pred[0], test_serial_pred[0]);

        assert!(test_parallel_pred[0] > 3.0 && test_parallel_pred[0] < 10.0);
    }
}

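/// Computes per-class weights for the labels in `y` according to `strategy`.
/// For `ClassWeight::Balanced` the weight of class `c` is
/// `n_samples / (n_classes * count(c))`, matching scikit-learn's "balanced"
/// mode.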
pub fn calculate_class_weights(
    y: &Array1<i32>,
    strategy: &ClassWeight,
) -> Result<HashMap<i32, f64>> {
    match strategy {
        ClassWeight::None => {
            // Every class present in y gets a weight of 1.0.
            let unique_classes: Vec<i32> = y
                .iter()
                .cloned()
                .collect::<std::collections::HashSet<_>>()
                .into_iter()
                .collect();
            let weights = unique_classes
                .into_iter()
                .map(|class| (class, 1.0))
                .collect();
            Ok(weights)
        }
        ClassWeight::Balanced => {
            let mut class_counts: HashMap<i32, usize> = HashMap::new();
            for &class in y.iter() {
                *class_counts.entry(class).or_insert(0) += 1;
            }

            let n_samples = y.len() as f64;
            let n_classes = class_counts.len() as f64;

            // weight(c) = n_samples / (n_classes * count(c))
            let mut weights = HashMap::new();
            for (&class, &count) in &class_counts {
                let weight = n_samples / (n_classes * count as f64);
                weights.insert(class, weight);
            }
            Ok(weights)
        }
        ClassWeight::Custom(weights) => Ok(weights.clone()),
    }
}

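/// Draws `n_samples` row indices from `y` according to `strategy`. The
/// `_random_state` argument is accepted for API compatibility but is not yet
/// wired into the RNG, so sampling is currently non-deterministic.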
pub fn balanced_bootstrap_sample(
    y: &Array1<i32>,
    strategy: SamplingStrategy,
    n_samples: usize,
    _random_state: Option<u64>,
) -> Result<Vec<usize>> {
    let mut rng = scirs2_core::random::thread_rng();

    match strategy {
        SamplingStrategy::Bootstrap => {
            let mut indices = Vec::with_capacity(n_samples);
            for _ in 0..n_samples {
                indices.push(rng.gen_range(0..y.len()));
            }
            Ok(indices)
        }
        SamplingStrategy::BalancedBootstrap => {
            // Group sample indices by class label.
            let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
            for (idx, &class) in y.iter().enumerate() {
                class_indices.entry(class).or_default().push(idx);
            }

            let n_classes = class_indices.len();
            let samples_per_class = n_samples / n_classes;
            let extra_samples = n_samples % n_classes;

            let mut indices = Vec::with_capacity(n_samples);
            let mut extra_count = 0;

            for (_, class_idx_list) in class_indices.iter() {
                // Distribute the remainder one sample at a time across classes.
                let mut n_class_samples = samples_per_class;
                if extra_count < extra_samples {
                    n_class_samples += 1;
                    extra_count += 1;
                }

                for _ in 0..n_class_samples {
                    let random_idx = rng.gen_range(0..class_idx_list.len());
                    indices.push(class_idx_list[random_idx]);
                }
            }

            indices.shuffle(&mut rng);

            Ok(indices)
        }
        SamplingStrategy::Stratified => {
            let mut class_counts: HashMap<i32, usize> = HashMap::new();
            let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

            for (idx, &class) in y.iter().enumerate() {
                *class_counts.entry(class).or_insert(0) += 1;
                class_indices.entry(class).or_default().push(idx);
            }

            let total_samples = y.len() as f64;
            let mut indices = Vec::with_capacity(n_samples);

            for (&class, &count) in &class_counts {
                let class_proportion = count as f64 / total_samples;
                let class_samples = (n_samples as f64 * class_proportion).round() as usize;
                let class_idx_list = &class_indices[&class];

                for _ in 0..class_samples {
                    let random_idx = rng.gen_range(0..class_idx_list.len());
                    indices.push(class_idx_list[random_idx]);
                }
            }

            // Rounding can leave us short of (or beyond) n_samples; correct both ways.
            while indices.len() < n_samples {
                indices.push(rng.gen_range(0..y.len()));
            }
            indices.truncate(n_samples);

            indices.shuffle(&mut rng);

            Ok(indices)
        }
        SamplingStrategy::SMOTEBootstrap => {
            let mut class_counts: HashMap<i32, usize> = HashMap::new();
            let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

            for (idx, &class) in y.iter().enumerate() {
                *class_counts.entry(class).or_insert(0) += 1;
                class_indices.entry(class).or_default().push(idx);
            }

            let max_class_size = class_counts.values().max().copied().unwrap_or(0);
            let mut indices = Vec::new();

            for (&class, class_idx_list) in &class_indices {
                // Oversample each class toward the size of the largest class.
                let class_count = class_counts[&class];
                let oversample_ratio = max_class_size as f64 / class_count as f64;
                let target_samples = (n_samples as f64 / class_counts.len() as f64
                    * oversample_ratio)
                    .round() as usize;

                for _ in 0..target_samples {
                    let random_idx = rng.gen_range(0..class_idx_list.len());
                    indices.push(class_idx_list[random_idx]);
                }
            }

            indices.truncate(n_samples);

            indices.shuffle(&mut rng);

            Ok(indices)
        }
    }
}

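/// Pairwise diversity statistics for a classifier ensemble, averaged over all
/// classifier pairs. See `calculate_ensemble_diversity` for how each measure
/// is computed.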
#[derive(Debug, Clone)]
pub struct DiversityMeasures {
    /// Yule's Q-statistic; values near 1 indicate highly correlated classifiers.
    pub q_statistic: f64,
    /// Fraction of samples on which exactly one classifier of a pair is correct.
    pub disagreement: f64,
    /// Fraction of samples on which both classifiers of a pair are wrong.
    pub double_fault: f64,
    /// Pearson correlation between the pair's correctness indicators.
    pub correlation_coefficient: f64,
    /// Cohen's kappa between the pair's correctness indicators.
    pub kappa_statistic: f64,
    /// Mean entropy of the ensemble's per-sample vote distribution.
    pub prediction_entropy: f64,
    /// Accuracy of each individual classifier.
    pub individual_accuracies: Vec<f64>,
}

impl Default for DiversityMeasures {
    fn default() -> Self {
        Self::new()
    }
}

impl DiversityMeasures {
    pub fn new() -> Self {
        Self {
            q_statistic: 0.0,
            disagreement: 0.0,
            double_fault: 0.0,
            correlation_coefficient: 0.0,
            kappa_statistic: 0.0,
            prediction_entropy: 0.0,
            individual_accuracies: Vec::new(),
        }
    }

    pub fn summary(&self) -> String {
        format!(
            "Diversity Measures Summary:\n\
             Q-statistic: {:.4} (higher = less diverse)\n\
             Disagreement: {:.4} (higher = more diverse)\n\
             Double-fault: {:.4} (lower = better)\n\
             Correlation: {:.4} (lower = more diverse)\n\
             Kappa: {:.4} (lower = more diverse)\n\
             Prediction Entropy: {:.4} (higher = more diverse)\n\
             Mean Individual Accuracy: {:.4}",
            self.q_statistic,
            self.disagreement,
            self.double_fault,
            self.correlation_coefficient,
            self.kappa_statistic,
            self.prediction_entropy,
            self.individual_accuracies.iter().sum::<f64>()
                / self.individual_accuracies.len() as f64
        )
    }
}

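/// Computes pairwise diversity measures for an ensemble from its individual
/// predictions, shaped `(n_samples, n_classifiers)`. For each classifier
/// pair, samples are tallied by joint correctness (n11 both correct, n10/n01
/// exactly one correct, n00 both wrong); from these counts the function
/// derives Yule's Q `(n11*n00 - n10*n01) / (n11*n00 + n10*n01)`, disagreement
/// `(n10 + n01) / n`, double-fault `n00 / n`, the correlation coefficient,
/// and Cohen's kappa, each averaged over all pairs.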
pub fn calculate_ensemble_diversity(
    individual_predictions: &Array2<i32>,
    true_labels: &Array1<i32>,
) -> Result<DiversityMeasures> {
    let (n_samples, n_classifiers) = individual_predictions.dim();

    if n_samples == 0 || n_classifiers < 2 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 classifiers and some samples to calculate diversity".to_string(),
        ));
    }

    if true_labels.len() != n_samples {
        return Err(SklearsError::InvalidInput(
            "Number of true labels must match number of samples".to_string(),
        ));
    }

    let mut individual_accuracies = Vec::with_capacity(n_classifiers);
    for classifier_idx in 0..n_classifiers {
        let predictions = individual_predictions.column(classifier_idx);
        let accuracy = predictions
            .iter()
            .zip(true_labels.iter())
            .map(|(&pred, &true_label)| (pred == true_label) as i32)
            .sum::<i32>() as f64
            / n_samples as f64;
        individual_accuracies.push(accuracy);
    }

    let mut q_statistics = Vec::new();
    let mut disagreements = Vec::new();
    let mut double_faults = Vec::new();
    let mut correlations = Vec::new();
    let mut kappa_statistics = Vec::new();

    for i in 0..n_classifiers {
        for j in (i + 1)..n_classifiers {
            let pred_i = individual_predictions.column(i);
            let pred_j = individual_predictions.column(j);

            let mut n11 = 0; // both classifiers correct
            let mut n10 = 0; // only classifier i correct
            let mut n01 = 0; // only classifier j correct
            let mut n00 = 0; // both classifiers wrong

            for sample_idx in 0..n_samples {
                let i_correct = pred_i[sample_idx] == true_labels[sample_idx];
                let j_correct = pred_j[sample_idx] == true_labels[sample_idx];

                match (i_correct, j_correct) {
                    (true, true) => n11 += 1,
                    (true, false) => n10 += 1,
                    (false, true) => n01 += 1,
                    (false, false) => n00 += 1,
                }
            }

            let n11_f = n11 as f64;
            let n10_f = n10 as f64;
            let n01_f = n01 as f64;
            let n00_f = n00 as f64;
            let n_f = n_samples as f64;

            // Yule's Q-statistic; guard against a zero denominator.
            let q_stat = if (n11_f * n00_f + n10_f * n01_f) > 1e-10 {
                (n11_f * n00_f - n10_f * n01_f) / (n11_f * n00_f + n10_f * n01_f)
            } else {
                0.0
            };
            q_statistics.push(q_stat);

            let disagreement = (n10_f + n01_f) / n_f;
            disagreements.push(disagreement);

            let double_fault = n00_f / n_f;
            double_faults.push(double_fault);

            // Marginal accuracies of each classifier in the pair.
            let p_i = (n11_f + n10_f) / n_f;
            let p_j = (n11_f + n01_f) / n_f;
            let correlation = if p_i * (1.0 - p_i) * p_j * (1.0 - p_j) > 1e-10 {
                (n11_f / n_f - p_i * p_j) / ((p_i * (1.0 - p_i) * p_j * (1.0 - p_j)).sqrt())
            } else {
                0.0
            };
            correlations.push(correlation);

            // Cohen's kappa over the correct/incorrect agreement table.
            let p_observed = (n11_f + n00_f) / n_f;
            let p_expected = p_i * p_j + (1.0 - p_i) * (1.0 - p_j);

            let kappa = if (1.0 - p_expected).abs() > 1e-10 {
                (p_observed - p_expected) / (1.0 - p_expected)
            } else {
                0.0
            };
            kappa_statistics.push(kappa);
        }
    }

    let prediction_entropy = calculate_prediction_entropy(individual_predictions)?;

    Ok(DiversityMeasures {
        q_statistic: q_statistics.iter().sum::<f64>() / q_statistics.len() as f64,
        disagreement: disagreements.iter().sum::<f64>() / disagreements.len() as f64,
        double_fault: double_faults.iter().sum::<f64>() / double_faults.len() as f64,
        correlation_coefficient: correlations.iter().sum::<f64>() / correlations.len() as f64,
        kappa_statistic: kappa_statistics.iter().sum::<f64>() / kappa_statistics.len() as f64,
        prediction_entropy,
        individual_accuracies,
    })
}

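/// Average (base-2) Shannon entropy of the ensemble's per-sample vote
/// distribution: for each sample, `H = -sum_k p_k * log2(p_k)` where `p_k`
/// is the fraction of classifiers predicting class `k`. Higher values mean
/// more disagreement among ensemble members.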
fn calculate_prediction_entropy(individual_predictions: &Array2<i32>) -> Result<f64> {
    let (n_samples, n_classifiers) = individual_predictions.dim();
    let mut total_entropy = 0.0;

    for sample_idx in 0..n_samples {
        let sample_predictions = individual_predictions.row(sample_idx);

        // Count the votes for each predicted label on this sample.
        let mut prediction_counts: HashMap<i32, usize> = HashMap::new();
        for &prediction in sample_predictions.iter() {
            *prediction_counts.entry(prediction).or_insert(0) += 1;
        }

        let mut sample_entropy = 0.0;
        for count in prediction_counts.values() {
            let probability = *count as f64 / n_classifiers as f64;
            if probability > 1e-10 {
                sample_entropy -= probability * probability.log2();
            }
        }

        total_entropy += sample_entropy;
    }

    Ok(total_entropy / n_samples as f64)
}

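/// Diversity statistics for a regression ensemble: mean pairwise Pearson
/// correlation between regressors, average per-sample prediction variance,
/// RMS bias of the ensemble-mean prediction, and each regressor's RMSE.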
#[derive(Debug, Clone)]
pub struct RegressionDiversityMeasures {
    pub prediction_correlation: f64,
    pub prediction_variance: f64,
    pub average_bias: f64,
    pub average_variance: f64,
    pub individual_rmse: Vec<f64>,
}

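/// Computes `RegressionDiversityMeasures` from an `(n_samples, n_regressors)`
/// prediction matrix. `average_bias` is the square root of the mean squared
/// bias of the ensemble-mean prediction; `average_variance` is the mean
/// per-sample variance of the individual predictions around that mean.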
pub fn calculate_regression_diversity(
    individual_predictions: &Array2<f64>,
    true_values: &Array1<f64>,
) -> Result<RegressionDiversityMeasures> {
    let (n_samples, n_regressors) = individual_predictions.dim();

    if n_samples == 0 || n_regressors < 2 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 regressors and some samples".to_string(),
        ));
    }

    if true_values.len() != n_samples {
        return Err(SklearsError::InvalidInput(
            "Number of true values must match number of samples".to_string(),
        ));
    }

    // Per-regressor RMSE against the true values.
    let mut individual_rmse = Vec::with_capacity(n_regressors);
    for regressor_idx in 0..n_regressors {
        let predictions = individual_predictions.column(regressor_idx);
        let mse = predictions
            .iter()
            .zip(true_values.iter())
            .map(|(&pred, &true_val)| (pred - true_val).powi(2))
            .sum::<f64>()
            / n_samples as f64;
        individual_rmse.push(mse.sqrt());
    }

    // Mean pairwise Pearson correlation between regressors' predictions.
    let mut correlations = Vec::new();
    for i in 0..n_regressors {
        for j in (i + 1)..n_regressors {
            let pred_i = individual_predictions.column(i);
            let pred_j = individual_predictions.column(j);

            let correlation =
                calculate_pearson_correlation(&pred_i.to_owned(), &pred_j.to_owned())?;
            correlations.push(correlation);
        }
    }

    // Average per-sample variance of the individual predictions.
    let mut total_variance = 0.0;
    for sample_idx in 0..n_samples {
        let sample_predictions = individual_predictions.row(sample_idx);
        let mean_pred = sample_predictions.mean().unwrap();

        let variance = sample_predictions
            .iter()
            .map(|&pred| (pred - mean_pred).powi(2))
            .sum::<f64>()
            / n_regressors as f64;

        total_variance += variance;
    }
    let prediction_variance = total_variance / n_samples as f64;

    // Bias/variance-style decomposition of the ensemble-mean prediction.
    let mut total_bias = 0.0;
    let mut total_variance_component = 0.0;

    for sample_idx in 0..n_samples {
        let sample_predictions = individual_predictions.row(sample_idx);
        let mean_pred = sample_predictions.mean().unwrap();
        let true_val = true_values[sample_idx];

        let bias_squared = (mean_pred - true_val).powi(2);
        total_bias += bias_squared;

        let variance = sample_predictions
            .iter()
            .map(|&pred| (pred - mean_pred).powi(2))
            .sum::<f64>()
            / n_regressors as f64;
        total_variance_component += variance;
    }

    Ok(RegressionDiversityMeasures {
        prediction_correlation: correlations.iter().sum::<f64>() / correlations.len() as f64,
        prediction_variance,
        average_bias: (total_bias / n_samples as f64).sqrt(),
        average_variance: total_variance_component / n_samples as f64,
        individual_rmse,
    })
}

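/// Pearson correlation coefficient between two equal-length arrays. Returns
/// 0.0 when either array has (numerically) zero variance, where the
/// correlation is undefined.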
fn calculate_pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> Result<f64> {
    if x.len() != y.len() || x.len() < 2 {
        return Err(SklearsError::InvalidInput(
            "Arrays must have same length and at least 2 elements".to_string(),
        ));
    }

    let mean_x = x.mean().unwrap();
    let mean_y = y.mean().unwrap();

    let mut numerator = 0.0;
    let mut sum_sq_x = 0.0;
    let mut sum_sq_y = 0.0;

    for i in 0..x.len() {
        let diff_x = x[i] - mean_x;
        let diff_y = y[i] - mean_y;

        numerator += diff_x * diff_y;
        sum_sq_x += diff_x * diff_x;
        sum_sq_y += diff_y * diff_y;
    }

    let denominator = (sum_sq_x * sum_sq_y).sqrt();

    if denominator < 1e-10 {
        // Zero variance in either array: the correlation is undefined, return 0.
        Ok(0.0)
    } else {
        Ok(numerator / denominator)
    }
}