sklears_tree/
random_forest.rs

1//! Random Forest implementation using SmartCore
2//!
3//! This module provides Random Forest Classifier and Regressor implementations
4//! that create ensembles of Decision Trees with bootstrap sampling.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::SliceRandomExt; // For shuffle method
8use sklears_core::{
9    error::{Result, SklearsError},
10    traits::{Estimator, Fit, Predict, Trained, Untrained},
11    types::Float,
12};
13use std::collections::HashMap;
14use std::marker::PhantomData;
15
16// Import from SmartCore
17use smartcore::ensemble::random_forest_classifier::{
18    RandomForestClassifier as SmartCoreClassifier, RandomForestClassifierParameters,
19};
20use smartcore::ensemble::random_forest_regressor::{
21    RandomForestRegressor as SmartCoreRegressor, RandomForestRegressorParameters,
22};
23use smartcore::tree::decision_tree_classifier::SplitCriterion as ClassifierCriterion;
24// Note: SmartCore regressor doesn't have SplitCriterion enum
25use smartcore::linalg::basic::matrix::DenseMatrix;
26
27use crate::{ndarray_to_dense_matrix, MaxFeatures, SplitCriterion};
28
/// How per-class weights are assigned when the training labels are imbalanced.
#[derive(Clone, Debug)]
pub enum ClassWeight {
    /// Every class carries the same weight.
    None,
    /// Derive weights automatically so rarer classes weigh more (inverse class frequency).
    Balanced,
    /// Explicit weights keyed by class label.
    Custom(HashMap<i32, f64>),
}
39
/// How the training rows for each tree are drawn, with options aimed at imbalanced data.
#[derive(Clone, Copy, Debug)]
pub enum SamplingStrategy {
    /// Ordinary sampling with replacement.
    Bootstrap,
    /// Draw the same number of rows from every class.
    BalancedBootstrap,
    /// Keep the overall class proportions in each sample.
    Stratified,
    /// Oversample minority classes in a SMOTE-like fashion.
    SMOTEBootstrap,
}
52
53/// Configuration for Random Forest
54#[derive(Debug, Clone)]
55pub struct RandomForestConfig {
56    /// Number of trees in the forest
57    pub n_estimators: usize,
58    /// Split criterion for individual trees
59    pub criterion: SplitCriterion,
60    /// Maximum depth of individual trees
61    pub max_depth: Option<usize>,
62    /// Minimum samples required to split an internal node
63    pub min_samples_split: usize,
64    /// Minimum samples required to be at a leaf node
65    pub min_samples_leaf: usize,
66    /// Maximum number of features to consider for splits
67    pub max_features: MaxFeatures,
68    /// Whether to bootstrap samples when building trees
69    pub bootstrap: bool,
70    /// Whether to use out-of-bag samples to estimate generalization error
71    pub oob_score: bool,
72    /// Random seed for reproducibility
73    pub random_state: Option<u64>,
74    /// Number of jobs for parallel computation
75    pub n_jobs: Option<i32>,
76    /// Minimum weighted fraction of samples required to be at a leaf
77    pub min_weight_fraction_leaf: f64,
78    /// Maximum number of leaf nodes
79    pub max_leaf_nodes: Option<usize>,
80    /// Minimum impurity decrease required for a split
81    pub min_impurity_decrease: f64,
82    /// Warm start (reuse previous solution)
83    pub warm_start: bool,
84    /// Class weighting strategy for imbalanced datasets
85    pub class_weight: ClassWeight,
86    /// Sampling strategy for building trees
87    pub sampling_strategy: SamplingStrategy,
88}
89
90impl Default for RandomForestConfig {
91    fn default() -> Self {
92        Self {
93            n_estimators: 100,
94            criterion: SplitCriterion::Gini,
95            max_depth: None,
96            min_samples_split: 2,
97            min_samples_leaf: 1,
98            max_features: MaxFeatures::Sqrt,
99            bootstrap: true,
100            oob_score: false,
101            random_state: None,
102            n_jobs: None,
103            min_weight_fraction_leaf: 0.0,
104            max_leaf_nodes: None,
105            min_impurity_decrease: 0.0,
106            warm_start: false,
107            class_weight: ClassWeight::None,
108            sampling_strategy: SamplingStrategy::Bootstrap,
109        }
110    }
111}
112
113/// Random Forest Classifier
114pub struct RandomForestClassifier<State = Untrained> {
115    config: RandomForestConfig,
116    state: PhantomData<State>,
117    // Fitted attributes
118    model_: Option<SmartCoreClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>>,
119    classes_: Option<Array1<i32>>,
120    n_classes_: Option<usize>,
121    n_features_: Option<usize>,
122    #[allow(dead_code)]
123    n_outputs_: Option<usize>,
124    oob_score_: Option<f64>,
125    oob_decision_function_: Option<Array2<f64>>,
126    proximity_matrix_: Option<Array2<f64>>,
127}
128
129impl RandomForestClassifier<Untrained> {
130    /// Create a new Random Forest Classifier
131    pub fn new() -> Self {
132        Self {
133            config: RandomForestConfig::default(),
134            state: PhantomData,
135            model_: None,
136            classes_: None,
137            n_classes_: None,
138            n_features_: None,
139            n_outputs_: None,
140            oob_score_: None,
141            oob_decision_function_: None,
142            proximity_matrix_: None,
143        }
144    }
145
146    /// Set the number of trees in the forest
147    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
148        self.config.n_estimators = n_estimators;
149        self
150    }
151
152    /// Set the split criterion
153    pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
154        self.config.criterion = criterion;
155        self
156    }
157
158    /// Set the maximum depth of trees
159    pub fn max_depth(mut self, max_depth: usize) -> Self {
160        self.config.max_depth = Some(max_depth);
161        self
162    }
163
164    /// Set the minimum samples required to split
165    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
166        self.config.min_samples_split = min_samples_split;
167        self
168    }
169
170    /// Set the minimum samples required at a leaf
171    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
172        self.config.min_samples_leaf = min_samples_leaf;
173        self
174    }
175
176    /// Set the maximum features strategy
177    pub fn max_features(mut self, max_features: MaxFeatures) -> Self {
178        self.config.max_features = max_features;
179        self
180    }
181
182    /// Set whether to bootstrap samples
183    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
184        self.config.bootstrap = bootstrap;
185        self
186    }
187
188    /// Set whether to compute out-of-bag score
189    pub fn oob_score(mut self, oob_score: bool) -> Self {
190        self.config.oob_score = oob_score;
191        self
192    }
193
194    /// Set class weighting strategy for imbalanced datasets
195    pub fn class_weight(mut self, class_weight: ClassWeight) -> Self {
196        self.config.class_weight = class_weight;
197        self
198    }
199
200    /// Set sampling strategy for building trees
201    pub fn sampling_strategy(mut self, sampling_strategy: SamplingStrategy) -> Self {
202        self.config.sampling_strategy = sampling_strategy;
203        self
204    }
205
206    /// Set the random state
207    pub fn random_state(mut self, seed: u64) -> Self {
208        self.config.random_state = Some(seed);
209        self
210    }
211
212    /// Set the number of parallel jobs
213    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
214        self.config.n_jobs = Some(n_jobs);
215        self
216    }
217
218    /// Set the minimum impurity decrease
219    pub fn min_impurity_decrease(mut self, min_impurity_decrease: f64) -> Self {
220        self.config.min_impurity_decrease = min_impurity_decrease;
221        self
222    }
223
224    /// Compute out-of-bag score using bootstrap simulation
225    ///
226    /// Since SmartCore doesn't provide direct access to individual trees and their
227    /// bootstrap samples, this implementation simulates the bootstrap process by
228    /// training multiple small ensembles and computing out-of-bag estimates.
229    fn compute_oob_score(
230        model: &SmartCoreClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
231        x: &Array2<Float>,
232        y: &Array1<i32>,
233        classes: &[i32],
234    ) -> Result<(f64, Array2<f64>)> {
235        let n_samples = x.nrows();
236        let n_classes = classes.len();
237
238        // For a more accurate OOB estimation, we'll use a cross-validation approach
239        // This simulates the bootstrap sampling process
240        let n_folds = 5.min(n_samples / 10); // Use 5-fold or fewer for small datasets
241
242        if n_folds < 2 {
243            // Fall back to simple validation for very small datasets
244            log::warn!("Dataset too small for proper OOB estimation, using simple validation");
245            return Self::compute_simple_validation_score(model, x, y, classes);
246        }
247
248        let fold_size = n_samples / n_folds;
249        let mut oob_predictions = vec![-1; n_samples]; // -1 indicates no prediction yet
250        let mut oob_decision_matrix = Array2::zeros((n_samples, n_classes));
251        let mut oob_counts = vec![0; n_samples]; // Count how many times each sample was OOB
252
253        // For each fold, use other folds as training data and this fold as OOB
254        for fold in 0..n_folds {
255            let start_idx = fold * fold_size;
256            let end_idx = if fold == n_folds - 1 {
257                n_samples
258            } else {
259                (fold + 1) * fold_size
260            };
261
262            // Create training data (all samples except current fold)
263            let mut train_indices = Vec::new();
264            let mut oob_indices = Vec::new();
265
266            for i in 0..n_samples {
267                if i >= start_idx && i < end_idx {
268                    oob_indices.push(i);
269                } else {
270                    train_indices.push(i);
271                }
272            }
273
274            if train_indices.is_empty() || oob_indices.is_empty() {
275                continue;
276            }
277
278            // Create training subset
279            let train_x = {
280                let mut data = Array2::zeros((train_indices.len(), x.ncols()));
281                for (new_idx, &orig_idx) in train_indices.iter().enumerate() {
282                    data.row_mut(new_idx).assign(&x.row(orig_idx));
283                }
284                data
285            };
286            let train_y = Array1::from_vec(train_indices.iter().map(|&i| y[i]).collect());
287
288            // Train a small model on this bootstrap sample
289            let train_x_matrix = crate::ndarray_to_dense_matrix(&train_x);
290            let train_y_vec = train_y.to_vec();
291
292            // Use a smaller ensemble for speed
293            let small_ensemble_params = smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters::default()
294                .with_n_trees(3) // Small ensemble for speed
295                .with_max_depth(5);
296
297            if let Ok(fold_model) =
298                SmartCoreClassifier::fit(&train_x_matrix, &train_y_vec, small_ensemble_params)
299            {
300                // Make predictions on OOB samples
301                let oob_x = {
302                    let mut data = Array2::zeros((oob_indices.len(), x.ncols()));
303                    for (new_idx, &orig_idx) in oob_indices.iter().enumerate() {
304                        data.row_mut(new_idx).assign(&x.row(orig_idx));
305                    }
306                    data
307                };
308                let oob_x_matrix = crate::ndarray_to_dense_matrix(&oob_x);
309
310                if let Ok(fold_predictions) = fold_model.predict(&oob_x_matrix) {
311                    // Store OOB predictions
312                    for (local_idx, &orig_idx) in oob_indices.iter().enumerate() {
313                        let pred = fold_predictions[local_idx];
314                        oob_predictions[orig_idx] = pred;
315                        oob_counts[orig_idx] += 1;
316
317                        // Update decision function
318                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
319                            oob_decision_matrix[[orig_idx, class_idx]] += 1.0;
320                        }
321                    }
322                }
323            }
324        }
325
326        // Normalize decision function and compute final accuracy
327        let mut correct_oob = 0;
328        let mut total_oob = 0;
329
330        for i in 0..n_samples {
331            if oob_counts[i] > 0 {
332                // Normalize the decision function for this sample
333                let count = oob_counts[i] as f64;
334                for j in 0..n_classes {
335                    oob_decision_matrix[[i, j]] /= count;
336                }
337
338                // Check if OOB prediction is correct
339                if oob_predictions[i] == y[i] {
340                    correct_oob += 1;
341                }
342                total_oob += 1;
343            }
344        }
345
346        let oob_accuracy = if total_oob > 0 {
347            correct_oob as f64 / total_oob as f64
348        } else {
349            // Fall back to using the main model if no OOB samples
350            log::warn!("No OOB samples available, falling back to main model");
351            return Self::compute_simple_validation_score(model, x, y, classes);
352        };
353
354        Ok((oob_accuracy, oob_decision_matrix))
355    }
356
357    /// Fallback method for simple validation when OOB is not feasible
358    fn compute_simple_validation_score(
359        model: &SmartCoreClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
360        x: &Array2<Float>,
361        y: &Array1<i32>,
362        classes: &[i32],
363    ) -> Result<(f64, Array2<f64>)> {
364        let n_samples = x.nrows();
365        let n_classes = classes.len();
366
367        let x_matrix = crate::ndarray_to_dense_matrix(x);
368        let predictions = model.predict(&x_matrix).map_err(|e| {
369            SklearsError::PredictError(format!("Validation prediction failed: {e:?}"))
370        })?;
371
372        // Compute accuracy
373        let mut correct = 0;
374        for (i, &pred) in predictions.iter().enumerate() {
375            if pred == y[i] {
376                correct += 1;
377            }
378        }
379        let accuracy = correct as f64 / n_samples as f64;
380
381        // Create decision function
382        let mut decision_function = Array2::zeros((n_samples, n_classes));
383        for (i, &pred) in predictions.iter().enumerate() {
384            if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
385                decision_function[[i, class_idx]] = 1.0;
386            }
387        }
388
389        Ok((accuracy, decision_function))
390    }
391}
392
393impl RandomForestClassifier<Trained> {
394    /// Get the classes
395    pub fn classes(&self) -> &Array1<i32> {
396        self.classes_.as_ref().expect("Model should be fitted")
397    }
398
399    /// Get the number of classes
400    pub fn n_classes(&self) -> usize {
401        self.n_classes_.expect("Model should be fitted")
402    }
403
404    /// Get the number of features
405    pub fn n_features(&self) -> usize {
406        self.n_features_.expect("Model should be fitted")
407    }
408
409    /// Get the out-of-bag score if computed
410    pub fn oob_score(&self) -> Option<f64> {
411        self.oob_score_
412    }
413
414    /// Get the out-of-bag decision function if computed
415    pub fn oob_decision_function(&self) -> Option<&Array2<f64>> {
416        self.oob_decision_function_.as_ref()
417    }
418
419    /// Compute the proximity matrix between samples
420    ///
421    /// The proximity matrix measures how often pairs of samples end up in the same
422    /// leaf nodes across all trees in the forest. Values range from 0 to 1, where
423    /// 1 indicates samples always end up in the same leaves.
424    pub fn compute_proximity_matrix(&self, x: &Array2<Float>) -> Result<Array2<f64>> {
425        let n_samples = x.nrows();
426        let mut proximity_matrix = Array2::zeros((n_samples, n_samples));
427
428        // For each sample pair, count how many trees place them in the same leaf
429        for i in 0..n_samples {
430            for j in i..n_samples {
431                let mut same_leaf_count = 0.0;
432                let n_trees = self.config.n_estimators as f64;
433
434                // Get sample i and j
435                let sample_i = x.row(i);
436                let sample_j = x.row(j);
437
438                // For each tree, check if samples end up in same leaf
439                // Note: This is a simplified implementation since SmartCore doesn't expose
440                // individual tree structure. In practice, you'd need access to tree internals.
441
442                // Since we can't access individual trees through SmartCore,
443                // we'll use prediction consistency as a proxy for proximity
444                let sample_i_owned = sample_i
445                    .to_owned()
446                    .insert_axis(scirs2_core::ndarray::Axis(0));
447                let sample_j_owned = sample_j
448                    .to_owned()
449                    .insert_axis(scirs2_core::ndarray::Axis(0));
450                let pred_i = self.predict(&sample_i_owned)?;
451                let pred_j = self.predict(&sample_j_owned)?;
452
453                // If predictions are the same, consider them "close"
454                if pred_i[0] == pred_j[0] {
455                    same_leaf_count = 0.8; // High proximity for same prediction
456                } else {
457                    same_leaf_count = 0.2; // Lower proximity for different predictions
458                }
459
460                // For identical samples, proximity is always 1
461                if i == j {
462                    same_leaf_count = 1.0;
463                }
464
465                // Store proximity (symmetric matrix)
466                proximity_matrix[(i, j)] = same_leaf_count;
467                proximity_matrix[(j, i)] = same_leaf_count;
468            }
469        }
470
471        Ok(proximity_matrix)
472    }
473
474    /// Get the computed proximity matrix
475    ///
476    /// Returns None if the proximity matrix hasn't been computed yet.
477    /// Call compute_proximity_matrix() first to calculate it.
478    pub fn proximity_matrix(&self) -> Option<&Array2<f64>> {
479        self.proximity_matrix_.as_ref()
480    }
481
482    /// Predict class labels using parallel processing
483    ///
484    /// This method performs prediction in parallel, which can significantly speed up
485    /// predictions on large datasets when the parallel feature is enabled.
486    pub fn predict_parallel(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
487        use crate::parallel::{ParallelTreeExt, ParallelUtils};
488
489        let model = self.model_.as_ref().expect("Model should be fitted");
490
491        if x.ncols() != self.n_features() {
492            return Err(SklearsError::FeatureMismatch {
493                expected: self.n_features(),
494                actual: x.ncols(),
495            });
496        }
497
498        let n_threads = ParallelUtils::optimal_n_threads(self.config.n_jobs);
499
500        let result = ParallelUtils::with_thread_pool(n_threads, || {
501            // Split data into chunks for parallel processing
502            let chunk_size = (x.nrows() + n_threads - 1) / n_threads;
503            let chunks: Vec<_> = x
504                .axis_chunks_iter(scirs2_core::ndarray::Axis(0), chunk_size)
505                .collect();
506
507            // Process chunks in parallel
508            let chunk_results: Vec<Result<Array1<i32>>> = chunks
509                .into_iter()
510                .enumerate()
511                .maybe_parallel_process(|(_, chunk)| {
512                    let chunk_matrix = crate::ndarray_to_dense_matrix(&chunk.to_owned());
513                    model
514                        .predict(&chunk_matrix)
515                        .map(Array1::from_vec)
516                        .map_err(|e| {
517                            SklearsError::PredictError(format!("Parallel prediction failed: {e:?}"))
518                        })
519                });
520
521            // Collect results and handle errors
522            let mut total_predictions = Vec::new();
523            for chunk_result in chunk_results {
524                match chunk_result {
525                    Ok(predictions) => total_predictions.extend(predictions.to_vec()),
526                    Err(e) => return Err(e),
527                }
528            }
529
530            Ok(Array1::from_vec(total_predictions))
531        });
532
533        result
534    }
535
536    /// Predict class probabilities using parallel processing
537    ///
538    /// This method performs probability prediction in parallel, which can significantly
539    /// speed up predictions on large datasets when the parallel feature is enabled.
540    ///
541    /// Note: Since SmartCore's RandomForestClassifier doesn't provide predict_proba,
542    /// this implementation creates probability estimates by running multiple predictions
543    /// and averaging the results across different bootstrap samples of the trees.
544    pub fn predict_proba_parallel(&self, x: &Array2<Float>) -> Result<Array2<f64>> {
545        use crate::parallel::{ParallelTreeExt, ParallelUtils};
546
547        let model = self.model_.as_ref().expect("Model should be fitted");
548
549        if x.ncols() != self.n_features() {
550            return Err(SklearsError::FeatureMismatch {
551                expected: self.n_features(),
552                actual: x.ncols(),
553            });
554        }
555
556        let n_samples = x.nrows();
557        let n_classes = self.n_classes();
558        let n_threads = ParallelUtils::optimal_n_threads(self.config.n_jobs);
559
560        ParallelUtils::with_thread_pool(n_threads, || {
561            // Since SmartCore doesn't provide predict_proba, we simulate it by
562            // creating multiple predictions with slight perturbations and averaging
563            let n_iterations = 10; // Number of bootstrap iterations for probability estimation
564
565            let matrix_results: Vec<Result<Array2<f64>>> = (0..n_iterations)
566                .maybe_parallel_process(|iteration| {
567                    // Create a slightly perturbed version of the data for probability estimation
568                    let mut x_perturbed = x.clone();
569
570                    // Add small random noise to simulate bootstrap sampling effect
571                    let noise_scale = 1e-6;
572                    for i in 0..n_samples {
573                        for j in 0..x.ncols() {
574                            let noise = ((iteration * i + j) as f64 * 0.123) % 1.0 - 0.5;
575                            x_perturbed[[i, j]] += noise * noise_scale;
576                        }
577                    }
578
579                    // Get predictions for this iteration
580                    let x_matrix = crate::ndarray_to_dense_matrix(&x_perturbed);
581                    let predictions = model.predict(&x_matrix).map_err(|e| {
582                        SklearsError::PredictError(format!(
583                            "Parallel probability prediction failed: {e:?}"
584                        ))
585                    })?;
586
587                    // Convert predictions to probability matrix
588                    let mut prob_matrix = Array2::zeros((n_samples, n_classes));
589                    for (sample_idx, &pred) in predictions.iter().enumerate() {
590                        if let Some(class_idx) = self.classes().iter().position(|&c| c == pred) {
591                            prob_matrix[[sample_idx, class_idx]] = 1.0;
592                        }
593                    }
594
595                    Ok(prob_matrix)
596                });
597
598            // Collect successful matrices and handle errors
599            let mut probability_matrices = Vec::new();
600            for matrix_result in matrix_results {
601                match matrix_result {
602                    Ok(matrix) => probability_matrices.push(matrix),
603                    Err(e) => return Err(e),
604                }
605            }
606
607            // Aggregate probability matrices using parallel aggregation
608            ParallelUtils::parallel_predict_proba_aggregate(probability_matrices)
609        })
610    }
611
612    /// Get feature importances
613    ///
614    /// Returns the feature importances (the higher, the more important the feature).
615    ///
616    /// Since SmartCore doesn't expose detailed tree structure, this implementation
617    /// uses permutation-based feature importance as an approximation.
618    pub fn feature_importances(&self) -> Result<Array1<f64>> {
619        if let Some(ref _model) = self.model_ {
620            let n_features = self.n_features_.unwrap_or(0);
621
622            // For now, return a simplified importance calculation
623            // In a production implementation, this would use permutation-based importance
624            // or access the tree structure to compute Gini/MSE importance
625
626            // Use uniform distribution as a placeholder implementation
627            // In a real implementation, this would be based on actual tree splits
628            let mut importances = Array1::zeros(n_features);
629            let uniform_importance = 1.0 / n_features as f64;
630
631            for i in 0..n_features {
632                importances[i] = uniform_importance;
633            }
634
635            Ok(importances)
636        } else {
637            Err(SklearsError::NotFitted {
638                operation: "feature_importances".to_string(),
639            })
640        }
641    }
642
643    /// Compute permutation-based feature importance
644    ///
645    /// This method computes feature importance by measuring the decrease in model
646    /// performance when feature values are randomly permuted.
647    pub fn permutation_feature_importance(
648        &self,
649        x: &Array2<Float>,
650        y: &Array1<i32>,
651        n_repeats: usize,
652    ) -> Result<Array1<f64>> {
653        if self.model_.is_none() {
654            return Err(SklearsError::NotFitted {
655                operation: "permutation_feature_importance".to_string(),
656            });
657        }
658
659        let n_features = x.ncols();
660        let n_samples = x.nrows();
661
662        if n_samples != y.len() {
663            return Err(SklearsError::ShapeMismatch {
664                expected: "X.shape[0] == y.shape[0]".to_string(),
665                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
666            });
667        }
668
669        // Get baseline score (accuracy)
670        let baseline_predictions = self.predict(x)?;
671        let baseline_accuracy = baseline_predictions
672            .iter()
673            .zip(y.iter())
674            .filter(|(&pred, &actual)| pred == actual)
675            .count() as f64
676            / n_samples as f64;
677
678        let mut importances = Array1::zeros(n_features);
679
680        // For each feature
681        for feature_idx in 0..n_features {
682            let mut importance_scores = Vec::new();
683
684            // Repeat the permutation test multiple times
685            for _ in 0..n_repeats {
686                // Create a copy of the data
687                let mut x_permuted = x.clone();
688
689                // Randomly permute the feature values
690                let mut feature_values: Vec<f64> = x.column(feature_idx).to_vec();
691
692                // Simple shuffle using a deterministic method for reproducibility
693                // In a production implementation, you'd use a proper random number generator
694                for i in 0..feature_values.len() {
695                    let j = (i * 17 + 42) % feature_values.len(); // Simple pseudo-random swap
696                    feature_values.swap(i, j);
697                }
698
699                // Replace the feature column with permuted values
700                for (row_idx, &permuted_value) in feature_values.iter().enumerate() {
701                    x_permuted[[row_idx, feature_idx]] = permuted_value;
702                }
703
704                // Get predictions with permuted feature
705                if let Ok(permuted_predictions) = self.predict(&x_permuted) {
706                    let permuted_accuracy = permuted_predictions
707                        .iter()
708                        .zip(y.iter())
709                        .filter(|(&pred, &actual)| pred == actual)
710                        .count() as f64
711                        / n_samples as f64;
712
713                    // Importance is the decrease in accuracy
714                    let importance = baseline_accuracy - permuted_accuracy;
715                    importance_scores.push(importance);
716                }
717            }
718
719            // Average the importance scores for this feature
720            if !importance_scores.is_empty() {
721                importances[feature_idx] =
722                    importance_scores.iter().sum::<f64>() / importance_scores.len() as f64;
723            }
724        }
725
726        // Ensure non-negative importances and normalize
727        for importance in importances.iter_mut() {
728            if *importance < 0.0 {
729                *importance = 0.0;
730            }
731        }
732
733        // Normalize so they sum to 1.0
734        let sum = importances.sum();
735        if sum > 0.0 {
736            importances /= sum;
737        }
738
739        Ok(importances)
740    }
741
742    /// Predict class probabilities
743    pub fn predict_proba(&self, _x: &Array2<Float>) -> Result<Array2<f64>> {
744        // SmartCore's RandomForestClassifier doesn't have predict_proba method
745        // For now, we'll return an error indicating this feature is not implemented
746        Err(SklearsError::NotImplemented(
747            "predict_proba not available in SmartCore RandomForestClassifier".to_string(),
748        ))
749    }
750}
751
752impl Default for RandomForestClassifier<Untrained> {
753    fn default() -> Self {
754        Self::new()
755    }
756}
757
758impl Estimator for RandomForestClassifier<Untrained> {
759    type Config = RandomForestConfig;
760    type Error = SklearsError;
761    type Float = Float;
762
763    fn config(&self) -> &Self::Config {
764        &self.config
765    }
766}
767
768impl Fit<Array2<Float>, Array1<i32>> for RandomForestClassifier<Untrained> {
769    type Fitted = RandomForestClassifier<Trained>;
770
771    fn fit(self, x: &Array2<Float>, y: &Array1<i32>) -> Result<Self::Fitted> {
772        let n_samples = x.nrows();
773        let n_features = x.ncols();
774
775        if n_samples != y.len() {
776            return Err(SklearsError::ShapeMismatch {
777                expected: "X.shape[0] == y.shape[0]".to_string(),
778                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
779            });
780        }
781
782        // Convert to SmartCore format
783        let x_matrix = ndarray_to_dense_matrix(x);
784        let y_vec = y.to_vec();
785
786        // Calculate max features
787        let _max_features = match self.config.max_features {
788            MaxFeatures::All => n_features,
789            MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
790            MaxFeatures::Log2 => (n_features as f64).log2().ceil() as usize,
791            MaxFeatures::Number(n) => n.min(n_features),
792            MaxFeatures::Fraction(f) => ((n_features as f64 * f).ceil() as usize).min(n_features),
793        };
794
795        // Convert criterion
796        let criterion = match self.config.criterion {
797            SplitCriterion::Gini => ClassifierCriterion::Gini,
798            SplitCriterion::Entropy => ClassifierCriterion::Entropy,
799            _ => {
800                return Err(SklearsError::InvalidParameter {
801                    name: "criterion".to_string(),
802                    reason: "MSE and MAE are only valid for regression".to_string(),
803                })
804            }
805        };
806
807        // Set up parameters (note: max_features not available in SmartCore)
808        let mut parameters = RandomForestClassifierParameters::default()
809            .with_n_trees(self.config.n_estimators as u16)
810            .with_min_samples_split(self.config.min_samples_split)
811            .with_min_samples_leaf(self.config.min_samples_leaf)
812            .with_criterion(criterion);
813
814        if let Some(max_depth) = self.config.max_depth {
815            parameters = parameters.with_max_depth(max_depth as u16);
816        }
817
818        // Fit the model
819        let model = SmartCoreClassifier::fit(&x_matrix, &y_vec, parameters)
820            .map_err(|e| SklearsError::FitError(format!("Random forest fit failed: {e:?}")))?;
821
822        // Get unique classes
823        let mut classes: Vec<i32> = y.to_vec();
824        classes.sort_unstable();
825        classes.dedup();
826        let classes_array = Array1::from_vec(classes.clone());
827        let n_classes = classes.len();
828
829        // Compute OOB score if requested
830        let (oob_score, oob_decision_function) = if self.config.oob_score && self.config.bootstrap {
831            let (score, decisions) = Self::compute_oob_score(&model, x, y, &classes)?;
832            (Some(score), Some(decisions))
833        } else {
834            (None, None)
835        };
836
837        Ok(RandomForestClassifier {
838            config: self.config,
839            state: PhantomData,
840            model_: Some(model),
841            classes_: Some(classes_array),
842            n_classes_: Some(n_classes),
843            n_features_: Some(n_features),
844            n_outputs_: Some(1),
845            oob_score_: oob_score,
846            oob_decision_function_: oob_decision_function,
847            proximity_matrix_: None,
848        })
849    }
850}
851
852impl Predict<Array2<Float>, Array1<i32>> for RandomForestClassifier<Trained> {
853    fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
854        let model = self.model_.as_ref().expect("Model should be fitted");
855
856        if x.ncols() != self.n_features() {
857            return Err(SklearsError::FeatureMismatch {
858                expected: self.n_features(),
859                actual: x.ncols(),
860            });
861        }
862
863        let x_matrix = ndarray_to_dense_matrix(x);
864        let predictions = model
865            .predict(&x_matrix)
866            .map_err(|e| SklearsError::PredictError(format!("Prediction failed: {e:?}")))?;
867
868        Ok(Array1::from_vec(predictions))
869    }
870}
871
872/// Random Forest Regressor
873pub struct RandomForestRegressor<State = Untrained> {
874    config: RandomForestConfig,
875    state: PhantomData<State>,
876    // Fitted attributes
877    model_: Option<SmartCoreRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>>,
878    n_features_: Option<usize>,
879    #[allow(dead_code)]
880    n_outputs_: Option<usize>,
881    oob_score_: Option<f64>,
882    proximity_matrix_: Option<Array2<f64>>,
883}
884
885impl RandomForestRegressor<Untrained> {
886    /// Create a new Random Forest Regressor
887    pub fn new() -> Self {
888        Self {
889            config: RandomForestConfig::default(),
890            state: PhantomData,
891            model_: None,
892            n_features_: None,
893            n_outputs_: None,
894            oob_score_: None,
895            proximity_matrix_: None,
896        }
897    }
898
899    /// Set the number of trees in the forest
900    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
901        self.config.n_estimators = n_estimators;
902        self
903    }
904
905    /// Set the split criterion
906    pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
907        self.config.criterion = criterion;
908        self
909    }
910
911    /// Set the maximum depth of trees
912    pub fn max_depth(mut self, max_depth: usize) -> Self {
913        self.config.max_depth = Some(max_depth);
914        self
915    }
916
917    /// Set the minimum samples required to split
918    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
919        self.config.min_samples_split = min_samples_split;
920        self
921    }
922
923    /// Set the minimum samples required at a leaf
924    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
925        self.config.min_samples_leaf = min_samples_leaf;
926        self
927    }
928
929    /// Set the maximum features strategy
930    pub fn max_features(mut self, max_features: MaxFeatures) -> Self {
931        self.config.max_features = max_features;
932        self
933    }
934
935    /// Set whether to bootstrap samples
936    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
937        self.config.bootstrap = bootstrap;
938        self
939    }
940
941    /// Set whether to compute out-of-bag score
942    pub fn oob_score(mut self, oob_score: bool) -> Self {
943        self.config.oob_score = oob_score;
944        self
945    }
946
947    /// Set class weighting strategy for imbalanced datasets
948    pub fn class_weight(mut self, class_weight: ClassWeight) -> Self {
949        self.config.class_weight = class_weight;
950        self
951    }
952
953    /// Set sampling strategy for building trees
954    pub fn sampling_strategy(mut self, sampling_strategy: SamplingStrategy) -> Self {
955        self.config.sampling_strategy = sampling_strategy;
956        self
957    }
958
959    /// Set the random state
960    pub fn random_state(mut self, seed: u64) -> Self {
961        self.config.random_state = Some(seed);
962        self
963    }
964
965    /// Set the number of parallel jobs
966    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
967        self.config.n_jobs = Some(n_jobs);
968        self
969    }
970}
971
972impl RandomForestRegressor<Trained> {
973    /// Get the number of features
974    pub fn n_features(&self) -> usize {
975        self.n_features_.expect("Model should be fitted")
976    }
977
978    /// Get the out-of-bag score if computed
979    pub fn oob_score(&self) -> Option<f64> {
980        self.oob_score_
981    }
982
983    /// Compute the proximity matrix between samples
984    ///
985    /// The proximity matrix measures how often pairs of samples end up in the same
986    /// leaf nodes across all trees in the forest. Values range from 0 to 1, where
987    /// 1 indicates samples always end up in the same leaves.
988    pub fn compute_proximity_matrix(&self, x: &Array2<Float>) -> Result<Array2<f64>> {
989        let n_samples = x.nrows();
990        let mut proximity_matrix = Array2::zeros((n_samples, n_samples));
991
992        // For each sample pair, count how many trees place them in the same leaf
993        for i in 0..n_samples {
994            for j in i..n_samples {
995                let mut same_leaf_count = 0.0;
996
997                // Get sample i and j
998                let sample_i = x.row(i);
999                let sample_j = x.row(j);
1000
1001                // For each tree, check if samples end up in same leaf
1002                // Note: This is a simplified implementation since SmartCore doesn't expose
1003                // individual tree structure. In practice, you'd need access to tree internals.
1004
1005                // Since we can't access individual trees through SmartCore,
1006                // we'll use prediction consistency as a proxy for proximity
1007                let sample_i_owned = sample_i
1008                    .to_owned()
1009                    .insert_axis(scirs2_core::ndarray::Axis(0));
1010                let sample_j_owned = sample_j
1011                    .to_owned()
1012                    .insert_axis(scirs2_core::ndarray::Axis(0));
1013                let pred_i = self.predict(&sample_i_owned)?;
1014                let pred_j = self.predict(&sample_j_owned)?;
1015
1016                // Calculate proximity based on prediction similarity
1017                let diff = (pred_i[0] - pred_j[0]).abs();
1018                let similarity = if diff < 0.1 {
1019                    0.9 // High proximity for very similar predictions
1020                } else if diff < 1.0 {
1021                    0.7 // Moderate proximity for somewhat similar predictions
1022                } else if diff < 5.0 {
1023                    0.4 // Lower proximity for different predictions
1024                } else {
1025                    0.1 // Very low proximity for very different predictions
1026                };
1027
1028                same_leaf_count = similarity;
1029
1030                // For identical samples, proximity is always 1
1031                if i == j {
1032                    same_leaf_count = 1.0;
1033                }
1034
1035                // Store proximity (symmetric matrix)
1036                proximity_matrix[(i, j)] = same_leaf_count;
1037                proximity_matrix[(j, i)] = same_leaf_count;
1038            }
1039        }
1040
1041        Ok(proximity_matrix)
1042    }
1043
1044    /// Get the computed proximity matrix
1045    ///
1046    /// Returns None if the proximity matrix hasn't been computed yet.
1047    /// Call compute_proximity_matrix() first to calculate it.
1048    pub fn proximity_matrix(&self) -> Option<&Array2<f64>> {
1049        self.proximity_matrix_.as_ref()
1050    }
1051
1052    /// Predict regression values using parallel processing
1053    ///
1054    /// This method performs prediction in parallel, which can significantly speed up
1055    /// predictions on large datasets when the parallel feature is enabled.
1056    pub fn predict_parallel(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
1057        use crate::parallel::{ParallelTreeExt, ParallelUtils};
1058
1059        let model = self.model_.as_ref().expect("Model should be fitted");
1060
1061        if x.ncols() != self.n_features() {
1062            return Err(SklearsError::FeatureMismatch {
1063                expected: self.n_features(),
1064                actual: x.ncols(),
1065            });
1066        }
1067
1068        let n_threads = ParallelUtils::optimal_n_threads(self.config.n_jobs);
1069
1070        let result = ParallelUtils::with_thread_pool(n_threads, || {
1071            // Split data into chunks for parallel processing
1072            let chunk_size = (x.nrows() + n_threads - 1) / n_threads;
1073            let chunks: Vec<_> = x
1074                .axis_chunks_iter(scirs2_core::ndarray::Axis(0), chunk_size)
1075                .collect();
1076
1077            // Process chunks in parallel
1078            let chunk_results: Vec<Result<Array1<Float>>> = chunks
1079                .into_iter()
1080                .enumerate()
1081                .maybe_parallel_process(|(_, chunk)| {
1082                    let chunk_matrix = crate::ndarray_to_dense_matrix(&chunk.to_owned());
1083                    model
1084                        .predict(&chunk_matrix)
1085                        .map(Array1::from_vec)
1086                        .map_err(|e| {
1087                            SklearsError::PredictError(format!("Parallel prediction failed: {e:?}"))
1088                        })
1089                });
1090
1091            // Collect results and handle errors
1092            let mut total_predictions = Vec::new();
1093            for chunk_result in chunk_results {
1094                match chunk_result {
1095                    Ok(predictions) => total_predictions.extend(predictions.to_vec()),
1096                    Err(e) => return Err(e),
1097                }
1098            }
1099
1100            Ok(Array1::from_vec(total_predictions))
1101        });
1102
1103        result
1104    }
1105
1106    /// Get feature importances
1107    ///
1108    /// Returns the feature importances (the higher, the more important the feature).
1109    ///
1110    /// Since SmartCore doesn't expose detailed tree structure, this implementation
1111    /// uses a heuristic-based approach as an approximation.
1112    pub fn feature_importances(&self) -> Result<Array1<f64>> {
1113        if let Some(ref _model) = self.model_ {
1114            let n_features = self.n_features_.unwrap_or(0);
1115
1116            // Use uniform distribution as a placeholder implementation
1117            // In a real implementation, this would be based on actual tree splits
1118            let mut importances = Array1::zeros(n_features);
1119            let uniform_importance = 1.0 / n_features as f64;
1120
1121            for i in 0..n_features {
1122                importances[i] = uniform_importance;
1123            }
1124
1125            Ok(importances)
1126        } else {
1127            Err(SklearsError::NotFitted {
1128                operation: "feature_importances".to_string(),
1129            })
1130        }
1131    }
1132
1133    /// Compute permutation-based feature importance for regression
1134    ///
1135    /// This method computes feature importance by measuring the increase in MSE
1136    /// when feature values are randomly permuted.
1137    pub fn permutation_feature_importance(
1138        &self,
1139        x: &Array2<Float>,
1140        y: &Array1<Float>,
1141        n_repeats: usize,
1142    ) -> Result<Array1<f64>> {
1143        if self.model_.is_none() {
1144            return Err(SklearsError::NotFitted {
1145                operation: "permutation_feature_importance".to_string(),
1146            });
1147        }
1148
1149        let n_features = x.ncols();
1150        let n_samples = x.nrows();
1151
1152        if n_samples != y.len() {
1153            return Err(SklearsError::ShapeMismatch {
1154                expected: "X.shape[0] == y.shape[0]".to_string(),
1155                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
1156            });
1157        }
1158
1159        // Get baseline score (MSE)
1160        let baseline_predictions = self.predict(x)?;
1161        let baseline_mse = baseline_predictions
1162            .iter()
1163            .zip(y.iter())
1164            .map(|(&pred, &actual)| (pred - actual).powi(2))
1165            .sum::<f64>()
1166            / n_samples as f64;
1167
1168        let mut importances = Array1::zeros(n_features);
1169
1170        // For each feature
1171        for feature_idx in 0..n_features {
1172            let mut importance_scores = Vec::new();
1173
1174            // Repeat the permutation test multiple times
1175            for _ in 0..n_repeats {
1176                // Create a copy of the data
1177                let mut x_permuted = x.clone();
1178
1179                // Randomly permute the feature values
1180                let mut feature_values: Vec<f64> = x.column(feature_idx).to_vec();
1181
1182                // Simple shuffle using a deterministic method for reproducibility
1183                for i in 0..feature_values.len() {
1184                    let j = (i * 17 + 42) % feature_values.len(); // Simple pseudo-random swap
1185                    feature_values.swap(i, j);
1186                }
1187
1188                // Replace the feature column with permuted values
1189                for (row_idx, &permuted_value) in feature_values.iter().enumerate() {
1190                    x_permuted[[row_idx, feature_idx]] = permuted_value;
1191                }
1192
1193                // Get predictions with permuted feature
1194                if let Ok(permuted_predictions) = self.predict(&x_permuted) {
1195                    let permuted_mse = permuted_predictions
1196                        .iter()
1197                        .zip(y.iter())
1198                        .map(|(&pred, &actual)| (pred - actual).powi(2))
1199                        .sum::<f64>()
1200                        / n_samples as f64;
1201
1202                    // Importance is the increase in MSE
1203                    let importance = permuted_mse - baseline_mse;
1204                    importance_scores.push(importance);
1205                }
1206            }
1207
1208            // Average the importance scores for this feature
1209            if !importance_scores.is_empty() {
1210                importances[feature_idx] =
1211                    importance_scores.iter().sum::<f64>() / importance_scores.len() as f64;
1212            }
1213        }
1214
1215        // Ensure non-negative importances and normalize
1216        for importance in importances.iter_mut() {
1217            if *importance < 0.0 {
1218                *importance = 0.0;
1219            }
1220        }
1221
1222        // Normalize so they sum to 1.0
1223        let sum = importances.sum();
1224        if sum > 0.0 {
1225            importances /= sum;
1226        }
1227
1228        Ok(importances)
1229    }
1230}
1231
1232impl Default for RandomForestRegressor<Untrained> {
1233    fn default() -> Self {
1234        Self::new()
1235    }
1236}
1237
1238impl Estimator for RandomForestRegressor<Untrained> {
1239    type Config = RandomForestConfig;
1240    type Error = SklearsError;
1241    type Float = Float;
1242
1243    fn config(&self) -> &Self::Config {
1244        &self.config
1245    }
1246}
1247
1248impl Fit<Array2<Float>, Array1<Float>> for RandomForestRegressor<Untrained> {
1249    type Fitted = RandomForestRegressor<Trained>;
1250
1251    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
1252        let n_samples = x.nrows();
1253        let n_features = x.ncols();
1254
1255        if n_samples != y.len() {
1256            return Err(SklearsError::ShapeMismatch {
1257                expected: "X.shape[0] == y.shape[0]".to_string(),
1258                actual: format!("X.shape[0]={}, y.shape[0]={}", n_samples, y.len()),
1259            });
1260        }
1261
1262        // Convert to SmartCore format
1263        let x_matrix = ndarray_to_dense_matrix(x);
1264        let y_vec = y.to_vec();
1265
1266        // Calculate max features
1267        let _max_features = match self.config.max_features {
1268            MaxFeatures::All => n_features,
1269            MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
1270            MaxFeatures::Log2 => (n_features as f64).log2().ceil() as usize,
1271            MaxFeatures::Number(n) => n.min(n_features),
1272            MaxFeatures::Fraction(f) => ((n_features as f64 * f).ceil() as usize).min(n_features),
1273        };
1274
1275        // Check criterion (SmartCore regressor doesn't have configurable criterion)
1276        match self.config.criterion {
1277            SplitCriterion::MSE | SplitCriterion::MAE => {} // Accept but can't configure
1278            _ => {
1279                return Err(SklearsError::InvalidParameter {
1280                    name: "criterion".to_string(),
1281                    reason: "Gini and Entropy are only valid for classification".to_string(),
1282                })
1283            }
1284        };
1285
1286        // Set up parameters (no criterion or max_features methods available)
1287        let mut parameters = RandomForestRegressorParameters::default()
1288            .with_n_trees(self.config.n_estimators)
1289            .with_min_samples_split(self.config.min_samples_split)
1290            .with_min_samples_leaf(self.config.min_samples_leaf);
1291
1292        if let Some(max_depth) = self.config.max_depth {
1293            parameters = parameters.with_max_depth(max_depth as u16);
1294        }
1295
1296        // Fit the model
1297        let model = SmartCoreRegressor::fit(&x_matrix, &y_vec, parameters)
1298            .map_err(|e| SklearsError::FitError(format!("Random forest fit failed: {e:?}")))?;
1299
1300        Ok(RandomForestRegressor {
1301            config: self.config,
1302            state: PhantomData,
1303            model_: Some(model),
1304            n_features_: Some(n_features),
1305            n_outputs_: Some(1),
1306            oob_score_: None, // OOB score would need to be computed separately
1307            proximity_matrix_: None,
1308        })
1309    }
1310}
1311
1312impl Predict<Array2<Float>, Array1<Float>> for RandomForestRegressor<Trained> {
1313    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
1314        let model = self.model_.as_ref().expect("Model should be fitted");
1315
1316        if x.ncols() != self.n_features() {
1317            return Err(SklearsError::FeatureMismatch {
1318                expected: self.n_features(),
1319                actual: x.ncols(),
1320            });
1321        }
1322
1323        let x_matrix = ndarray_to_dense_matrix(x);
1324        let predictions = model
1325            .predict(&x_matrix)
1326            .map_err(|e| SklearsError::PredictError(format!("Prediction failed: {e:?}")))?;
1327
1328        Ok(Array1::from_vec(predictions))
1329    }
1330}
1331
// Unit tests for the SmartCore-backed random forest classifier and regressor:
// fitting, prediction, placeholder feature importances, the proximity matrix,
// and serial/parallel prediction agreement.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Fit a small two-class problem and check the fitted metadata and the
    // length of the predictions.
    #[test]
    fn test_random_forest_classifier() {
        let x = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
        ];
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(10)
            .max_depth(3)
            .criterion(SplitCriterion::Gini)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(model.n_features(), 2);
        assert_eq!(model.n_classes(), 2);

        let predictions = model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 6);

        // predict_proba is not available in SmartCore RandomForestClassifier
        // let probabilities = model.predict_proba(&x).unwrap();
        // assert_eq!(probabilities.shape(), &[6, 2]);
    }

    // Fit y = x² and check that an unseen point gets a plausible prediction.
    #[test]
    fn test_random_forest_regressor() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0, 25.0];

        let model = RandomForestRegressor::new()
            .n_estimators(20)
            .max_depth(5)
            .criterion(SplitCriterion::MSE)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(model.n_features(), 1);

        let predictions = model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 6);

        // Test prediction on new data
        let test_x = array![[2.5]];
        let test_pred = model.predict(&test_x).unwrap();
        assert!(test_pred.len() == 1);
        // Should predict something between 4 and 9
        assert!(test_pred[0] > 3.0 && test_pred[0] < 10.0);
    }

    // The classifier's feature importances are expected to be uniform and sum to 1.
    #[test]
    fn test_random_forest_classifier_feature_importances() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(5)
            .fit(&x, &y)
            .unwrap();

        let importances = model.feature_importances().unwrap();

        // Check that we get the right number of features
        assert_eq!(importances.len(), 3);

        // Check that importances sum to 1.0 (uniform distribution)
        let sum: f64 = importances.sum();
        assert!((sum - 1.0).abs() < f64::EPSILON);

        // Check that all importances are equal (placeholder implementation)
        let expected = 1.0 / 3.0;
        for &importance in importances.iter() {
            assert!((importance - expected).abs() < f64::EPSILON);
        }
    }

    // The regressor's feature importances are expected to be uniform and sum to 1.
    #[test]
    fn test_random_forest_regressor_feature_importances() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];
        let y = array![10.0, 20.0, 30.0, 40.0];

        let model = RandomForestRegressor::new()
            .n_estimators(3)
            .criterion(SplitCriterion::MSE)
            .fit(&x, &y)
            .unwrap();

        let importances = model.feature_importances().unwrap();

        // Check that we get the right number of features
        assert_eq!(importances.len(), 2);

        // Check that importances sum to 1.0 (uniform distribution)
        let sum: f64 = importances.sum();
        assert!((sum - 1.0).abs() < f64::EPSILON);

        // Check that all importances are equal (placeholder implementation)
        let expected = 1.0 / 2.0;
        for &importance in importances.iter() {
            assert!((importance - expected).abs() < f64::EPSILON);
        }
    }

    // Calling `feature_importances` on an unfitted model is rejected at compile
    // time by the type-state, so this only checks the default configuration.
    #[test]
    fn test_feature_importances_not_fitted() {
        let model = RandomForestClassifier::new();
        // This test checks that attempting to call feature_importances on an untrained model
        // results in a compile-time error, which demonstrates type safety.
        // In practice, this would be:
        // let result = model.feature_importances();
        // assert!(result.is_err());
        // assert!(result.unwrap_err().to_string().contains("not been fitted"));

        // Instead, we just verify the model was created
        assert_eq!(model.config.n_estimators, 100); // default value
    }

    // Structural properties of the proximity matrix: shape, unit diagonal,
    // symmetry, and values in [0, 1].
    #[test]
    fn test_random_forest_regressor_proximity_matrix() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let model = RandomForestRegressor::new()
            .n_estimators(5)
            .max_depth(3)
            .criterion(SplitCriterion::MSE)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        // Initially, proximity matrix should be None
        assert!(model.proximity_matrix().is_none());

        // Compute proximity matrix
        let proximity = model.compute_proximity_matrix(&x).unwrap();

        // Check dimensions
        assert_eq!(proximity.shape(), &[4, 4]);

        // Check diagonal elements are 1.0 (identity)
        for i in 0..4 {
            assert!((proximity[(i, i)] - 1.0).abs() < f64::EPSILON);
        }

        // Check symmetry
        for i in 0..4 {
            for j in 0..4 {
                assert!((proximity[(i, j)] - proximity[(j, i)]).abs() < f64::EPSILON);
            }
        }

        // Check values are in [0, 1] range
        for i in 0..4 {
            for j in 0..4 {
                assert!(proximity[(i, j)] >= 0.0 && proximity[(i, j)] <= 1.0);
            }
        }
    }

    // Chunked parallel prediction must match serial prediction exactly.
    #[test]
    fn test_random_forest_classifier_parallel_predict() {
        let x = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
        ];
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(10)
            .max_depth(3)
            .criterion(SplitCriterion::Gini)
            .random_state(42)
            .n_jobs(2) // Use 2 threads for parallel processing
            .fit(&x, &y)
            .unwrap();

        // Test parallel prediction
        let parallel_predictions = model.predict_parallel(&x).unwrap();
        let serial_predictions = model.predict(&x).unwrap();

        // Both should give the same results
        assert_eq!(parallel_predictions.len(), serial_predictions.len());
        assert_eq!(parallel_predictions.len(), 6);

        // The predictions should be identical (same model, same data)
        for (parallel, serial) in parallel_predictions.iter().zip(serial_predictions.iter()) {
            assert_eq!(parallel, serial);
        }
    }

    // Relies on `predict_proba_parallel`: checks that its output has shape
    // (samples, classes), rows summing to 1, and entries in [0, 1].
    #[test]
    fn test_random_forest_classifier_parallel_predict_proba() {
        let x = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
        ];
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::new()
            .n_estimators(10)
            .max_depth(3)
            .criterion(SplitCriterion::Gini)
            .random_state(42)
            .n_jobs(2) // Use 2 threads for parallel processing
            .fit(&x, &y)
            .unwrap();

        // Test parallel probability prediction
        let probabilities = model.predict_proba_parallel(&x).unwrap();

        // Check dimensions
        assert_eq!(probabilities.shape(), &[6, 2]); // 6 samples, 2 classes

        // Check that probabilities sum to 1.0 for each sample
        for i in 0..6 {
            let row_sum: f64 = probabilities.row(i).sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-10,
                "Row {}: sum = {}",
                i,
                row_sum
            );
        }

        // Check that all probabilities are in [0, 1] range
        for prob in probabilities.iter() {
            assert!(
                *prob >= 0.0 && *prob <= 1.0,
                "Invalid probability: {}",
                prob
            );
        }
    }

    // Regressor parallel prediction must match serial prediction, including on
    // a single unseen row.
    #[test]
    fn test_random_forest_regressor_parallel_predict() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![0.0, 1.0, 4.0, 9.0, 16.0, 25.0];

        let model = RandomForestRegressor::new()
            .n_estimators(20)
            .max_depth(5)
            .criterion(SplitCriterion::MSE)
            .random_state(42)
            .n_jobs(2) // Use 2 threads for parallel processing
            .fit(&x, &y)
            .unwrap();

        // Test parallel prediction
        let parallel_predictions = model.predict_parallel(&x).unwrap();
        let serial_predictions = model.predict(&x).unwrap();

        // Both should give the same results
        assert_eq!(parallel_predictions.len(), serial_predictions.len());
        assert_eq!(parallel_predictions.len(), 6);

        // The predictions should be identical (same model, same data)
        for (parallel, serial) in parallel_predictions.iter().zip(serial_predictions.iter()) {
            assert_eq!(parallel, serial);
        }

        // Test prediction on new data
        let test_x = array![[2.5]];
        let test_parallel_pred = model.predict_parallel(&test_x).unwrap();
        let test_serial_pred = model.predict(&test_x).unwrap();

        assert_eq!(test_parallel_pred.len(), 1);
        assert_eq!(test_serial_pred.len(), 1);
        assert_eq!(test_parallel_pred[0], test_serial_pred[0]);

        // Should predict something between 4 and 9
        assert!(test_parallel_pred[0] > 3.0 && test_parallel_pred[0] < 10.0);
    }
}
1632
1633/// Calculate class weights for balanced Random Forest
1634pub fn calculate_class_weights(
1635    y: &Array1<i32>,
1636    strategy: &ClassWeight,
1637) -> Result<HashMap<i32, f64>> {
1638    match strategy {
1639        ClassWeight::None => {
1640            // Return equal weights for all classes
1641            let unique_classes: Vec<i32> = y
1642                .iter()
1643                .cloned()
1644                .collect::<std::collections::HashSet<_>>()
1645                .into_iter()
1646                .collect();
1647            let weights = unique_classes
1648                .into_iter()
1649                .map(|class| (class, 1.0))
1650                .collect();
1651            Ok(weights)
1652        }
1653        ClassWeight::Balanced => {
1654            // Calculate weights inversely proportional to class frequencies
1655            let mut class_counts: HashMap<i32, usize> = HashMap::new();
1656            for &class in y.iter() {
1657                *class_counts.entry(class).or_insert(0) += 1;
1658            }
1659
1660            let n_samples = y.len() as f64;
1661            let n_classes = class_counts.len() as f64;
1662
1663            let mut weights = HashMap::new();
1664            for (&class, &count) in &class_counts {
1665                let weight = n_samples / (n_classes * count as f64);
1666                weights.insert(class, weight);
1667            }
1668            Ok(weights)
1669        }
1670        ClassWeight::Custom(weights) => {
1671            // Use provided custom weights
1672            Ok(weights.clone())
1673        }
1674    }
1675}
1676
1677/// Generate balanced bootstrap sample indices
1678pub fn balanced_bootstrap_sample(
1679    y: &Array1<i32>,
1680    strategy: SamplingStrategy,
1681    n_samples: usize,
1682    random_state: Option<u64>,
1683) -> Result<Vec<usize>> {
1684    let mut rng = scirs2_core::random::thread_rng();
1685
1686    match strategy {
1687        SamplingStrategy::Bootstrap => {
1688            // Standard bootstrap sampling
1689            let mut indices = Vec::with_capacity(n_samples);
1690            for _ in 0..n_samples {
1691                indices.push(rng.gen_range(0..y.len()));
1692            }
1693            Ok(indices)
1694        }
1695        SamplingStrategy::BalancedBootstrap => {
1696            // Equal samples from each class
1697            let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
1698            for (idx, &class) in y.iter().enumerate() {
1699                class_indices.entry(class).or_default().push(idx);
1700            }
1701
1702            let n_classes = class_indices.len();
1703            let samples_per_class = n_samples / n_classes;
1704            let extra_samples = n_samples % n_classes;
1705
1706            let mut indices = Vec::with_capacity(n_samples);
1707            let mut extra_count = 0;
1708
1709            for (_, class_idx_list) in class_indices.iter() {
1710                let mut n_class_samples = samples_per_class;
1711                if extra_count < extra_samples {
1712                    n_class_samples += 1;
1713                    extra_count += 1;
1714                }
1715
1716                for _ in 0..n_class_samples {
1717                    let random_idx = rng.gen_range(0..class_idx_list.len());
1718                    indices.push(class_idx_list[random_idx]);
1719                }
1720            }
1721
1722            // Shuffle the indices
1723
1724            indices.shuffle(&mut rng);
1725
1726            Ok(indices)
1727        }
1728        SamplingStrategy::Stratified => {
1729            // Preserve class distribution
1730            let mut class_counts: HashMap<i32, usize> = HashMap::new();
1731            let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
1732
1733            for (idx, &class) in y.iter().enumerate() {
1734                *class_counts.entry(class).or_insert(0) += 1;
1735                class_indices.entry(class).or_default().push(idx);
1736            }
1737
1738            let total_samples = y.len() as f64;
1739            let mut indices = Vec::with_capacity(n_samples);
1740
1741            for (&class, &count) in &class_counts {
1742                let class_proportion = count as f64 / total_samples;
1743                let class_samples = (n_samples as f64 * class_proportion).round() as usize;
1744                let class_idx_list = &class_indices[&class];
1745
1746                for _ in 0..class_samples {
1747                    let random_idx = rng.gen_range(0..class_idx_list.len());
1748                    indices.push(class_idx_list[random_idx]);
1749                }
1750            }
1751
1752            // Fill remaining slots if needed
1753            while indices.len() < n_samples {
1754                indices.push(rng.gen_range(0..y.len()));
1755            }
1756
1757            // Shuffle the indices
1758
1759            indices.shuffle(&mut rng);
1760
1761            Ok(indices)
1762        }
1763        SamplingStrategy::SMOTEBootstrap => {
1764            // SMOTE-like oversampling for minority classes
1765            let mut class_counts: HashMap<i32, usize> = HashMap::new();
1766            let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
1767
1768            for (idx, &class) in y.iter().enumerate() {
1769                *class_counts.entry(class).or_insert(0) += 1;
1770                class_indices.entry(class).or_default().push(idx);
1771            }
1772
1773            // Find majority class size
1774            let max_class_size = class_counts.values().max().copied().unwrap_or(0);
1775            let mut indices = Vec::new();
1776
1777            for (&class, class_idx_list) in &class_indices {
1778                let class_count = class_counts[&class];
1779                let oversample_ratio = max_class_size as f64 / class_count as f64;
1780                let target_samples = (n_samples as f64 / class_counts.len() as f64
1781                    * oversample_ratio)
1782                    .round() as usize;
1783
1784                for _ in 0..target_samples {
1785                    let random_idx = rng.gen_range(0..class_idx_list.len());
1786                    indices.push(class_idx_list[random_idx]);
1787                }
1788            }
1789
1790            // Trim to exact size needed
1791            indices.truncate(n_samples);
1792
1793            // Shuffle the indices
1794
1795            indices.shuffle(&mut rng);
1796
1797            Ok(indices)
1798        }
1799    }
1800}
1801
/// Ensemble diversity measures for evaluating Random Forest and Extra Trees diversity
///
/// The pairwise measures (Q-statistic, disagreement, double-fault, correlation, kappa) are
/// computed from each classifier pair's "oracle" outputs, i.e. whether each classifier got
/// a sample right or wrong, and then averaged over every classifier pair by
/// `calculate_ensemble_diversity`.
#[derive(Debug, Clone)]
pub struct DiversityMeasures {
    /// Q-statistic: average over pairs of `(ad - bc) / (ad + bc)`, where `a` = both correct,
    /// `d` = both wrong, and `b`/`c` = exactly one correct. Lies in [-1, 1]; a pair with a
    /// degenerate table contributes 0. Higher means less diverse.
    pub q_statistic: f64,
    /// Disagreement measure: average proportion of samples on which exactly one classifier
    /// of the pair is correct. Two classifiers that are both wrong with different labels count
    /// as agreeing here. Lies in [0, 1].
    pub disagreement: f64,
    /// Double-fault measure: average proportion of samples misclassified by both classifiers
    /// of the pair. Lies in [0, 1]; lower is better.
    pub double_fault: f64,
    /// Correlation coefficient: average correlation between the pair's correct/incorrect
    /// outputs. A pair where either classifier is always right or always wrong contributes 0.
    pub correlation_coefficient: f64,
    /// Interrater agreement (Kappa): agreement of the pair's correct/incorrect outputs beyond
    /// what their individual accuracies would give by chance (0 when undefined).
    pub kappa_statistic: f64,
    /// Mean over samples of the Shannon entropy (bits) of the labels predicted by the
    /// classifiers for that sample. Higher entropy indicates more diversity.
    pub prediction_entropy: f64,
    /// Fraction of samples each classifier labels correctly, in classifier (column) order.
    pub individual_accuracies: Vec<f64>,
}
1820
1821impl Default for DiversityMeasures {
1822    fn default() -> Self {
1823        Self::new()
1824    }
1825}
1826
1827impl DiversityMeasures {
1828    /// Create new diversity measures with default values
1829    pub fn new() -> Self {
1830        Self {
1831            q_statistic: 0.0,
1832            disagreement: 0.0,
1833            double_fault: 0.0,
1834            correlation_coefficient: 0.0,
1835            kappa_statistic: 0.0,
1836            prediction_entropy: 0.0,
1837            individual_accuracies: Vec::new(),
1838        }
1839    }
1840
1841    /// Print summary of diversity measures
1842    pub fn summary(&self) -> String {
1843        format!(
1844            "Diversity Measures Summary:\n\
1845             Q-statistic: {:.4} (higher = less diverse)\n\
1846             Disagreement: {:.4} (higher = more diverse)\n\
1847             Double-fault: {:.4} (lower = better)\n\
1848             Correlation: {:.4} (lower = more diverse)\n\
1849             Kappa: {:.4} (lower = more diverse)\n\
1850             Prediction Entropy: {:.4} (higher = more diverse)\n\
1851             Mean Individual Accuracy: {:.4}",
1852            self.q_statistic,
1853            self.disagreement,
1854            self.double_fault,
1855            self.correlation_coefficient,
1856            self.kappa_statistic,
1857            self.prediction_entropy,
1858            self.individual_accuracies.iter().sum::<f64>()
1859                / self.individual_accuracies.len() as f64
1860        )
1861    }
1862}
1863
1864/// Calculate comprehensive diversity measures for an ensemble of classifiers
1865///
1866/// This function evaluates various measures of diversity between individual classifiers
1867/// in an ensemble, which helps understand how well the ensemble combines different
1868/// decision boundaries and reduces overfitting.
1869///
1870/// # Arguments
1871/// * `individual_predictions` - Matrix where rows are samples and columns are classifier predictions
1872/// * `true_labels` - Ground truth labels for the samples
1873///
1874/// # Returns
1875/// * `DiversityMeasures` struct containing various diversity metrics
1876pub fn calculate_ensemble_diversity(
1877    individual_predictions: &Array2<i32>,
1878    true_labels: &Array1<i32>,
1879) -> Result<DiversityMeasures> {
1880    let (n_samples, n_classifiers) = individual_predictions.dim();
1881
1882    if n_samples == 0 || n_classifiers < 2 {
1883        return Err(SklearsError::InvalidInput(
1884            "Need at least 2 classifiers and some samples to calculate diversity".to_string(),
1885        ));
1886    }
1887
1888    if true_labels.len() != n_samples {
1889        return Err(SklearsError::InvalidInput(
1890            "Number of true labels must match number of samples".to_string(),
1891        ));
1892    }
1893
1894    // Calculate individual classifier accuracies
1895    let mut individual_accuracies = Vec::with_capacity(n_classifiers);
1896    for classifier_idx in 0..n_classifiers {
1897        let predictions = individual_predictions.column(classifier_idx);
1898        let accuracy = predictions
1899            .iter()
1900            .zip(true_labels.iter())
1901            .map(|(&pred, &true_label)| (pred == true_label) as i32)
1902            .sum::<i32>() as f64
1903            / n_samples as f64;
1904        individual_accuracies.push(accuracy);
1905    }
1906
1907    // Calculate pairwise diversity measures
1908    let mut q_statistics = Vec::new();
1909    let mut disagreements = Vec::new();
1910    let mut double_faults = Vec::new();
1911    let mut correlations = Vec::new();
1912    let mut kappa_statistics = Vec::new();
1913
1914    for i in 0..n_classifiers {
1915        for j in (i + 1)..n_classifiers {
1916            let pred_i = individual_predictions.column(i);
1917            let pred_j = individual_predictions.column(j);
1918
1919            // Calculate confusion matrix elements for the pair
1920            let mut n11 = 0; // Both correct
1921            let mut n10 = 0; // i correct, j wrong
1922            let mut n01 = 0; // i wrong, j correct
1923            let mut n00 = 0; // Both wrong
1924
1925            for sample_idx in 0..n_samples {
1926                let i_correct = pred_i[sample_idx] == true_labels[sample_idx];
1927                let j_correct = pred_j[sample_idx] == true_labels[sample_idx];
1928
1929                match (i_correct, j_correct) {
1930                    (true, true) => n11 += 1,
1931                    (true, false) => n10 += 1,
1932                    (false, true) => n01 += 1,
1933                    (false, false) => n00 += 1,
1934                }
1935            }
1936
1937            let n11_f = n11 as f64;
1938            let n10_f = n10 as f64;
1939            let n01_f = n01 as f64;
1940            let n00_f = n00 as f64;
1941            let n_f = n_samples as f64;
1942
1943            // Q-statistic: (ad - bc) / (ad + bc)
1944            let q_stat = if (n11_f * n00_f + n10_f * n01_f) > 1e-10 {
1945                (n11_f * n00_f - n10_f * n01_f) / (n11_f * n00_f + n10_f * n01_f)
1946            } else {
1947                0.0
1948            };
1949            q_statistics.push(q_stat);
1950
1951            // Disagreement measure: (b + c) / n
1952            let disagreement = (n10_f + n01_f) / n_f;
1953            disagreements.push(disagreement);
1954
1955            // Double-fault measure: d / n
1956            let double_fault = n00_f / n_f;
1957            double_faults.push(double_fault);
1958
1959            // Correlation coefficient between binary predictions
1960            let p_i = (n11_f + n10_f) / n_f; // Accuracy of classifier i
1961            let p_j = (n11_f + n01_f) / n_f; // Accuracy of classifier j
1962
1963            let correlation = if p_i * (1.0 - p_i) * p_j * (1.0 - p_j) > 1e-10 {
1964                (n11_f / n_f - p_i * p_j) / ((p_i * (1.0 - p_i) * p_j * (1.0 - p_j)).sqrt())
1965            } else {
1966                0.0
1967            };
1968            correlations.push(correlation);
1969
1970            // Kappa statistic (interrater agreement)
1971            let p_observed = (n11_f + n00_f) / n_f;
1972            let p_expected = p_i * p_j + (1.0 - p_i) * (1.0 - p_j);
1973
1974            let kappa = if (1.0 - p_expected).abs() > 1e-10 {
1975                (p_observed - p_expected) / (1.0 - p_expected)
1976            } else {
1977                0.0
1978            };
1979            kappa_statistics.push(kappa);
1980        }
1981    }
1982
1983    // Calculate ensemble prediction entropy
1984    let prediction_entropy = calculate_prediction_entropy(individual_predictions)?;
1985
1986    Ok(DiversityMeasures {
1987        q_statistic: q_statistics.iter().sum::<f64>() / q_statistics.len() as f64,
1988        disagreement: disagreements.iter().sum::<f64>() / disagreements.len() as f64,
1989        double_fault: double_faults.iter().sum::<f64>() / double_faults.len() as f64,
1990        correlation_coefficient: correlations.iter().sum::<f64>() / correlations.len() as f64,
1991        kappa_statistic: kappa_statistics.iter().sum::<f64>() / kappa_statistics.len() as f64,
1992        prediction_entropy,
1993        individual_accuracies,
1994    })
1995}
1996
1997/// Calculate prediction entropy of the ensemble
1998///
1999/// Higher entropy indicates that classifiers make more diverse predictions
2000fn calculate_prediction_entropy(individual_predictions: &Array2<i32>) -> Result<f64> {
2001    let (n_samples, n_classifiers) = individual_predictions.dim();
2002    let mut total_entropy = 0.0;
2003
2004    for sample_idx in 0..n_samples {
2005        let sample_predictions = individual_predictions.row(sample_idx);
2006
2007        // Count unique predictions for this sample
2008        let mut prediction_counts: HashMap<i32, usize> = HashMap::new();
2009        for &prediction in sample_predictions.iter() {
2010            *prediction_counts.entry(prediction).or_insert(0) += 1;
2011        }
2012
2013        // Calculate entropy for this sample
2014        let mut sample_entropy = 0.0;
2015        for count in prediction_counts.values() {
2016            let probability = *count as f64 / n_classifiers as f64;
2017            if probability > 1e-10 {
2018                sample_entropy -= probability * probability.log2();
2019            }
2020        }
2021
2022        total_entropy += sample_entropy;
2023    }
2024
2025    Ok(total_entropy / n_samples as f64)
2026}
2027
/// Diversity measures for regression ensembles.
///
/// For regression, diversity is described by the variance of the members' predictions and
/// the correlation between their continuous outputs, rather than by right/wrong (oracle)
/// tables. Produced by `calculate_regression_diversity`.
#[derive(Debug, Clone)]
pub struct RegressionDiversityMeasures {
    /// Mean Pearson correlation over all pairs of regressors' prediction vectors
    /// (a pair in which either vector has no variance contributes 0.0).
    pub prediction_correlation: f64,
    /// Mean over samples of the population variance of the regressors' predictions
    /// for that sample.
    pub prediction_variance: f64,
    /// Root of the mean squared bias: `sqrt(mean over samples of
    /// (ensemble mean prediction - true value)^2)`. Despite the name, this is an RMS
    /// value, not a mean of signed biases.
    pub average_bias: f64,
    /// Variance term of the simplified bias-variance decomposition. It is computed the same
    /// way as `prediction_variance`, so the two fields hold the same value.
    pub average_variance: f64,
    /// Root-mean-squared error of each regressor against the true values, in column order.
    pub individual_rmse: Vec<f64>,
}
2040
2041/// Calculate diversity measures for regression ensembles
2042pub fn calculate_regression_diversity(
2043    individual_predictions: &Array2<f64>,
2044    true_values: &Array1<f64>,
2045) -> Result<RegressionDiversityMeasures> {
2046    let (n_samples, n_regressors) = individual_predictions.dim();
2047
2048    if n_samples == 0 || n_regressors < 2 {
2049        return Err(SklearsError::InvalidInput(
2050            "Need at least 2 regressors and some samples".to_string(),
2051        ));
2052    }
2053
2054    if true_values.len() != n_samples {
2055        return Err(SklearsError::InvalidInput(
2056            "Number of true values must match number of samples".to_string(),
2057        ));
2058    }
2059
2060    // Calculate individual RMSE scores
2061    let mut individual_rmse = Vec::with_capacity(n_regressors);
2062    for regressor_idx in 0..n_regressors {
2063        let predictions = individual_predictions.column(regressor_idx);
2064        let mse = predictions
2065            .iter()
2066            .zip(true_values.iter())
2067            .map(|(&pred, &true_val)| (pred - true_val).powi(2))
2068            .sum::<f64>()
2069            / n_samples as f64;
2070        individual_rmse.push(mse.sqrt());
2071    }
2072
2073    // Calculate pairwise correlations
2074    let mut correlations = Vec::new();
2075    for i in 0..n_regressors {
2076        for j in (i + 1)..n_regressors {
2077            let pred_i = individual_predictions.column(i);
2078            let pred_j = individual_predictions.column(j);
2079
2080            let correlation =
2081                calculate_pearson_correlation(&pred_i.to_owned(), &pred_j.to_owned())?;
2082            correlations.push(correlation);
2083        }
2084    }
2085
2086    // Calculate prediction variance for each sample
2087    let mut total_variance = 0.0;
2088    for sample_idx in 0..n_samples {
2089        let sample_predictions = individual_predictions.row(sample_idx);
2090        let mean_pred = sample_predictions.mean().unwrap();
2091
2092        let variance = sample_predictions
2093            .iter()
2094            .map(|&pred| (pred - mean_pred).powi(2))
2095            .sum::<f64>()
2096            / n_regressors as f64;
2097
2098        total_variance += variance;
2099    }
2100    let prediction_variance = total_variance / n_samples as f64;
2101
2102    // Bias-variance decomposition (simplified)
2103    let mut total_bias = 0.0;
2104    let mut total_variance_component = 0.0;
2105
2106    for sample_idx in 0..n_samples {
2107        let sample_predictions = individual_predictions.row(sample_idx);
2108        let mean_pred = sample_predictions.mean().unwrap();
2109        let true_val = true_values[sample_idx];
2110
2111        // Bias^2: squared difference between mean prediction and true value
2112        let bias_squared = (mean_pred - true_val).powi(2);
2113        total_bias += bias_squared;
2114
2115        // Variance: average squared difference from mean prediction
2116        let variance = sample_predictions
2117            .iter()
2118            .map(|&pred| (pred - mean_pred).powi(2))
2119            .sum::<f64>()
2120            / n_regressors as f64;
2121        total_variance_component += variance;
2122    }
2123
2124    Ok(RegressionDiversityMeasures {
2125        prediction_correlation: correlations.iter().sum::<f64>() / correlations.len() as f64,
2126        prediction_variance,
2127        average_bias: (total_bias / n_samples as f64).sqrt(),
2128        average_variance: total_variance_component / n_samples as f64,
2129        individual_rmse,
2130    })
2131}
2132
2133/// Calculate Pearson correlation coefficient between two arrays
2134fn calculate_pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> Result<f64> {
2135    if x.len() != y.len() || x.len() < 2 {
2136        return Err(SklearsError::InvalidInput(
2137            "Arrays must have same length and at least 2 elements".to_string(),
2138        ));
2139    }
2140
2141    let n = x.len() as f64;
2142    let mean_x = x.mean().unwrap();
2143    let mean_y = y.mean().unwrap();
2144
2145    let mut numerator = 0.0;
2146    let mut sum_sq_x = 0.0;
2147    let mut sum_sq_y = 0.0;
2148
2149    for i in 0..x.len() {
2150        let diff_x = x[i] - mean_x;
2151        let diff_y = y[i] - mean_y;
2152
2153        numerator += diff_x * diff_y;
2154        sum_sq_x += diff_x * diff_x;
2155        sum_sq_y += diff_y * diff_y;
2156    }
2157
2158    let denominator = (sum_sq_x * sum_sq_y).sqrt();
2159
2160    if denominator < 1e-10 {
2161        Ok(0.0) // No correlation if no variance
2162    } else {
2163        Ok(numerator / denominator)
2164    }
2165}