// sklears_compose/ensemble/common.rs

//! Common utilities shared across ensemble implementations.

use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::types::Float;

/// Element-wise activation applied to ensemble outputs.
///
/// `simd_ensemble::apply_activation_simd` maps each variant to the formula
/// documented on it below.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ActivationFunction {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid: `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Hyperbolic tangent: `tanh(x)`.
    Tanh,
}

/// Portable scalar fallbacks for element-wise `f32` vector arithmetic.
///
/// Despite the module name, these are plain loops with no explicit SIMD
/// intrinsics. They iterate over zipped slices, which avoids a bounds check
/// on every element.
///
/// Every function processes exactly as many elements as its first input has.
/// Extra trailing elements in the second operand or in `result` are ignored,
/// and `result`'s extra elements are left untouched.
pub mod simd_fallback {
    /// Writes `op(a[i], b[i])` into `result[i]` for every `i < a.len()`.
    ///
    /// Lengths are validated once, before anything is written, so a size
    /// mismatch cannot leave `result` partially updated.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `result` is shorter than `a`.
    fn zip_with(a: &[f32], b: &[f32], result: &mut [f32], op: impl Fn(f32, f32) -> f32) {
        let n = a.len();
        assert!(
            b.len() >= n,
            "second operand has {} elements but at least {} are required",
            b.len(),
            n
        );
        assert!(
            result.len() >= n,
            "result buffer has {} elements but at least {} are required",
            result.len(),
            n
        );
        for ((out, &x), &y) in result[..n].iter_mut().zip(a).zip(&b[..n]) {
            *out = op(x, y);
        }
    }

    /// Adds two vectors element-wise: `result[i] = a[i] + b[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `result` is shorter than `a`.
    pub fn add_vec(a: &[f32], b: &[f32], result: &mut [f32]) {
        zip_with(a, b, result, |x, y| x + y);
    }

    /// Multiplies two vectors element-wise: `result[i] = a[i] * b[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `result` is shorter than `a`.
    pub fn multiply_vec(a: &[f32], b: &[f32], result: &mut [f32]) {
        zip_with(a, b, result, |x, y| x * y);
    }

    /// Divides two vectors element-wise: `result[i] = a[i] / b[i]`.
    ///
    /// Follows IEEE-754 `f32` division, so a zero divisor yields `±inf` or
    /// `NaN` rather than panicking.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `result` is shorter than `a`.
    pub fn divide_vec(a: &[f32], b: &[f32], result: &mut [f32]) {
        zip_with(a, b, result, |x, y| x / y);
    }

    /// Multiplies every element by `scale`: `result[i] = vec[i] * scale`.
    ///
    /// # Panics
    ///
    /// Panics if `result` is shorter than `vec`.
    pub fn scale_vec(vec: &[f32], scale: f32, result: &mut [f32]) {
        let n = vec.len();
        assert!(
            result.len() >= n,
            "result buffer has {} elements but at least {} are required",
            result.len(),
            n
        );
        for (out, &x) in result[..n].iter_mut().zip(vec) {
            *out = x * scale;
        }
    }
}

48/// Ensemble prediction statistics computed using SIMD
49#[derive(Debug, Clone)]
50pub struct EnsembleStatistics {
51    pub mean: Float,
52    pub variance: Float,
53    pub confidence: Float,
54    pub diversity: Float,
55    pub bias: Float,
56    pub prediction_entropy: Float,
57    pub disagreement: Float,
58    pub average_confidence: Float,
59    pub min_confidence: Float,
60    pub max_confidence: Float,
61    pub std_confidence: Float,
62    pub skew_confidence: Float,
63    pub kurtosis_confidence: Float,
64    pub median_confidence: Float,
65    pub iqr_confidence: Float,
66    pub prediction_stability: Float,
67    pub convergence_rate: Float,
68    pub ensemble_complexity: Float,
69    pub overfitting_risk: Float,
70    pub generalization_error: Float,
71    pub calibration_score: Float,
72}
73
74impl Default for EnsembleStatistics {
75    fn default() -> Self {
76        Self {
77            mean: 0.0,
78            variance: 0.0,
79            confidence: 0.0,
80            diversity: 0.0,
81            bias: 0.0,
82            prediction_entropy: 0.0,
83            disagreement: 0.0,
84            average_confidence: 0.0,
85            min_confidence: 0.0,
86            max_confidence: 0.0,
87            std_confidence: 0.0,
88            skew_confidence: 0.0,
89            kurtosis_confidence: 0.0,
90            median_confidence: 0.0,
91            iqr_confidence: 0.0,
92            prediction_stability: 0.0,
93            convergence_rate: 0.0,
94            ensemble_complexity: 0.0,
95            overfitting_risk: 0.0,
96            generalization_error: 0.0,
97            calibration_score: 0.0,
98        }
99    }
100}
101
102impl EnsembleStatistics {
103    /// Compute statistics from predictions array
104    #[must_use]
105    pub fn from_predictions(predictions: &Array2<Float>) -> Self {
106        let mut stats = EnsembleStatistics::default();
107
108        if predictions.is_empty() {
109            return stats;
110        }
111
112        // Calculate mean and variance across ensemble predictions
113        let mean = predictions.mean().unwrap_or(0.0);
114        let variance = predictions.var(0.0);
115
116        stats.mean = mean;
117        stats.variance = variance;
118        stats.confidence = 1.0 - variance; // Simple confidence measure
119
120        // Calculate diversity as variance across predictors
121        let row_means: Array1<Float> = predictions.mean_axis(Axis(1)).unwrap();
122        stats.diversity = row_means.var(0.0);
123
124        stats
125    }
126}
127
128/// SIMD ensemble operations module
129pub mod simd_ensemble {
130    use super::ActivationFunction;
131
132    /// Apply activation function using SIMD operations
133    pub fn apply_activation_simd(values: &mut Vec<f32>, activation: ActivationFunction) {
134        match activation {
135            ActivationFunction::ReLU => {
136                for value in values {
137                    *value = value.max(0.0);
138                }
139            }
140            ActivationFunction::Sigmoid => {
141                for value in values {
142                    *value = 1.0 / (1.0 + (-*value).exp());
143                }
144            }
145            ActivationFunction::Tanh => {
146                for value in values {
147                    *value = value.tanh();
148                }
149            }
150        }
151    }
152}